• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Service registry for apitools."""
18
19import collections
20import logging
21import re
22import textwrap
23
24from apitools.base.py import base_api
25from apitools.gen import util
26
27# We're a code generator. I don't care.
28# pylint:disable=too-many-statements
29
# Loose sanity check for media-upload ``accept`` entries: "type/subtype",
# case-insensitive, with ``*`` wildcards allowed. It is applied with
# ``re.match`` (prefix only, not a full match), and a mismatch is merely
# logged as a warning rather than rejected.
_MIME_PATTERN_RE = re.compile(r'(?i)[a-z0-9_*-]+/[a-z0-9_*-]+')
31
32
class ServiceRegistry(object):

    """Registry for service types.

    Collects per-service method information (``base_api.ApiMethodInfo``
    objects keyed by method name) while walking a discovery document's
    resources, registers any synthesized request/response message types with
    the message registry, and writes the collected services out either as a
    Python client module (``WriteFile``) or as proto ``service`` declarations
    (``WriteProtoFile``).
    """

    def __init__(self, client_info, message_registry,
                 names, root_package, base_files_package,
                 unelidable_request_methods):
        """Create an empty registry.

        Args:
          client_info: client description; this class reads attributes such
              as ``package``, ``scopes``, ``version``, ``base_url`` and
              ``base_path`` from it.
          message_registry: registry that receives synthesized message
              descriptors; it is validated before any output is written.
          names: naming helper producing class, field, method and clean names.
          root_package: package prefix for the generated ``messages`` import,
              or a falsy value to import the messages module unqualified.
          base_files_package: package ``base_api`` is imported from in the
              generated client.
          unelidable_request_methods: method ids that always get a dedicated
              request type, even when the body type could be reused.
        """
        self.__client_info = client_info
        self.__package = client_info.package
        self.__names = names
        self.__service_method_info_map = collections.OrderedDict()
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__unelidable_request_methods = unelidable_request_methods
        # WriteFile accepts a falsy client_info.scopes, so tolerate None here
        # instead of failing inside set().
        self.__all_scopes = set(client_info.scopes or ())

    def Validate(self):
        """Validate the underlying message registry."""
        self.__message_registry.Validate()

    @property
    def scopes(self):
        """Sorted list of client scopes plus every scope seen on a method."""
        return sorted(self.__all_scopes)

    def __GetServiceClassName(self, service_name):
        """Return the generated class name for ``service_name``."""
        return self.__names.ClassName(
            '%sService' % self.__names.ClassName(service_name))

    def __PrintDocstring(self, printer, method_info, method_name, name):
        """Print a docstring for a service method.

        The first line of the (cleaned) method description is terminated
        with a period; methods without a description get a generic summary.
        """
        if method_info.description:
            description = util.CleanDescription(method_info.description)
            # Split the *cleaned* text; the cleaned value used to be computed
            # and then thrown away.
            first_line, newline, remaining = description.partition('\n')
            if not first_line.endswith('.'):
                first_line = '%s.' % first_line
            description = '%s%s%s' % (first_line, newline, remaining)
        else:
            description = '%s method for the %s service.' % (method_name, name)
        with printer.CommentContext():
            printer('r"""%s' % description)
        printer()
        printer('Args:')
        printer('  request: (%s) input message', method_info.request_type_name)
        printer('  global_params: (StandardQueryParameters, default: None) '
                'global arguments')
        if method_info.upload_config:
            printer('  upload: (Upload, default: None) If present, upload')
            printer('      this stream with the request.')
        if method_info.supports_download:
            printer(
                '  download: (Download, default: None) If present, download')
            printer('      data from the request via this stream.')
        printer('Returns:')
        printer('  (%s) The response message.', method_info.response_type_name)
        printer('"""')

    def __WriteSingleService(
            self, printer, name, method_info_map, client_class_name):
        """Write the Python service class for one service.

        Emits the class header, an ``__init__`` that fills
        ``self._upload_configs``, then for every method a wrapper calling
        ``_RunMethod`` followed by its ``method_config`` lambda.
        """
        printer()
        class_name = self.__GetServiceClassName(name)
        printer('class %s(base_api.BaseApiService):', class_name)
        with printer.Indent():
            printer('"""Service class for the %s resource."""', name)
            printer()
            printer('_NAME = %s', repr(name))

            # Print the configs for the methods first.
            printer()
            printer('def __init__(self, client):')
            with printer.Indent():
                printer('super(%s.%s, self).__init__(client)',
                        client_class_name, class_name)
                printer('self._upload_configs = {')
                with printer.Indent(indent='    '):
                    for method_name, method_info in method_info_map.items():
                        upload_config = method_info.upload_config
                        if upload_config is not None:
                            printer(
                                "'%s': base_api.ApiUploadInfo(", method_name)
                            with printer.Indent(indent='    '):
                                attrs = sorted(
                                    x.name for x in upload_config.all_fields())
                                for attr in attrs:
                                    printer('%s=%r,',
                                            attr, getattr(upload_config, attr))
                            printer('),')
                    printer('}')

            # Now write each method in turn.
            for method_name, method_info in method_info_map.items():
                printer()
                params = ['self', 'request', 'global_params=None']
                if method_info.upload_config:
                    params.append('upload=None')
                if method_info.supports_download:
                    params.append('download=None')
                printer('def %s(%s):', method_name, ', '.join(params))
                with printer.Indent():
                    self.__PrintDocstring(
                        printer, method_info, method_name, name)
                    printer("config = self.GetMethodConfig('%s')", method_name)
                    upload_config = method_info.upload_config
                    if upload_config is not None:
                        printer("upload_config = self.GetUploadConfig('%s')",
                                method_name)
                    arg_lines = [
                        'config, request, global_params=global_params']
                    if method_info.upload_config:
                        arg_lines.append(
                            'upload=upload, upload_config=upload_config')
                    if method_info.supports_download:
                        arg_lines.append('download=download')
                    printer('return self._RunMethod(')
                    with printer.Indent(indent='    '):
                        for line in arg_lines[:-1]:
                            printer('%s,', line)
                        printer('%s)', arg_lines[-1])
                printer()
                printer('%s.method_config = lambda: base_api.ApiMethodInfo(',
                        method_name)
                with printer.Indent(indent='    '):
                    # upload_config is emitted through _upload_configs above
                    # and the description already went into the docstring.
                    attrs = sorted(x.name for x in method_info.all_fields())
                    for attr in attrs:
                        if attr in ('upload_config', 'description'):
                            continue
                        value = getattr(method_info, attr)
                        if value is not None:
                            printer('%s=%r,', attr, value)
                printer(')')

    def __WriteProtoServiceDeclaration(self, printer, name, method_info_map):
        """Write a single service declaration to a proto file."""
        printer()
        printer('service %s {', self.__GetServiceClassName(name))
        with printer.Indent():
            for method_name, method_info in method_info_map.items():
                for line in textwrap.wrap(method_info.description,
                                          printer.CalculateWidth() - 3):
                    printer('// %s', line)
                printer('rpc %s (%s) returns (%s);',
                        method_name,
                        method_info.request_type_name,
                        method_info.response_type_name)
        printer('}')

    def WriteProtoFile(self, printer):
        """Write the services in this registry to out as proto."""
        self.Validate()
        client_info = self.__client_info
        printer('// Generated services for %s version %s.',
                client_info.package, client_info.version)
        printer()
        printer('syntax = "proto2";')
        printer('package %s;', self.__package)
        printer('import "%s";', client_info.messages_proto_file_name)
        printer()
        for name, method_info_map in self.__service_method_info_map.items():
            self.__WriteProtoServiceDeclaration(printer, name, method_info_map)

    def WriteFile(self, printer):
        """Write the services in this registry to out as a Python module."""
        self.Validate()
        client_info = self.__client_info
        printer('"""Generated client library for %s version %s."""',
                client_info.package, client_info.version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        printer()
        printer('from __future__ import absolute_import')
        printer()
        printer('from %s import base_api', self.__base_files_package)
        if self.__root_package:
            import_prefix = 'from {0} '.format(self.__root_package)
        else:
            import_prefix = ''
        printer('%simport %s as messages', import_prefix,
                client_info.messages_rule_name)
        printer()
        printer()
        printer('class %s(base_api.BaseApiClient):',
                client_info.client_class_name)
        with printer.Indent():
            printer(
                '"""Generated client library for service %s version %s."""',
                client_info.package, client_info.version)
            printer()
            # Values are passed as printer arguments (as elsewhere in this
            # file) instead of being pre-formatted into the format string, so
            # a '%' inside a value can never be read as a format directive.
            printer('MESSAGES_MODULE = messages')
            printer('BASE_URL = %r', client_info.base_url)
            printer('MTLS_BASE_URL = %r', client_info.mtls_base_url)
            printer()
            printer('_PACKAGE = %r', client_info.package)
            printer('_SCOPES = %r',
                    client_info.scopes or
                    ['https://www.googleapis.com/auth/userinfo.email'])
            printer('_VERSION = %r', client_info.version)
            printer('_CLIENT_ID = %r', client_info.client_id)
            printer('_CLIENT_SECRET = %r', client_info.client_secret)
            printer('_USER_AGENT = %r', client_info.user_agent)
            printer('_CLIENT_CLASS_NAME = %r', client_info.client_class_name)
            printer('_URL_VERSION = %r', client_info.url_version)
            printer('_API_KEY = %r', client_info.api_key)
            printer()
            printer("def __init__(self, url='', credentials=None,")
            with printer.Indent(indent='             '):
                printer('get_credentials=True, http=None, model=None,')
                printer('log_request=False, log_response=False,')
                printer('credentials_args=None, default_global_params=None,')
                printer('additional_http_headers=None, '
                        'response_encoding=None):')
            with printer.Indent():
                printer('"""Create a new %s handle."""', client_info.package)
                printer('url = url or self.BASE_URL')
                printer(
                    'super(%s, self).__init__(', client_info.client_class_name)
                printer('    url, credentials=credentials,')
                printer('    get_credentials=get_credentials, http=http, '
                        'model=model,')
                printer('    log_request=log_request, '
                        'log_response=log_response,')
                printer('    credentials_args=credentials_args,')
                printer('    default_global_params=default_global_params,')
                printer('    additional_http_headers=additional_http_headers,')
                printer('    response_encoding=response_encoding)')
                for name in self.__service_method_info_map:
                    printer('self.%s = self.%s(self)',
                            name, self.__GetServiceClassName(name))
            for name, method_info_map in (
                    self.__service_method_info_map.items()):
                self.__WriteSingleService(
                    printer, name, method_info_map,
                    client_info.client_class_name)

    def __RegisterService(self, service_name, method_info_map):
        """Record ``method_info_map`` under ``service_name``.

        Raises:
          ValueError: if ``service_name`` was already registered.
        """
        if service_name in self.__service_method_info_map:
            raise ValueError(
                'Attempt to re-register descriptor %s' % service_name)
        self.__service_method_info_map[service_name] = method_info_map

    def __CreateRequestType(self, method_description, body_type=None):
        """Create a request type for this method.

        Parameters listed in ``parameterOrder`` come first, followed by any
        remaining parameters; the request body (if any) is appended as an
        extra field.

        Returns:
          The id of the message registered with the message registry.

        Raises:
          ValueError: if a parameter has no type, or the body field name
              collides with a parameter.
        """
        schema = {}
        schema['id'] = self.__names.ClassName('%sRequest' % (
            self.__names.ClassName(method_description['id'], separator='.'),))
        schema['type'] = 'object'
        schema['properties'] = collections.OrderedDict()
        # A method may declare parameterOrder without any parameters (e.g. an
        # empty list), so don't index 'parameters' unconditionally.
        parameters = method_description.get('parameters', {})
        if 'parameterOrder' not in method_description:
            ordered_parameters = list(parameters)
        else:
            ordered_parameters = method_description['parameterOrder'][:]
            for k in parameters:
                if k not in ordered_parameters:
                    ordered_parameters.append(k)
        for parameter_name in ordered_parameters:
            field = dict(parameters[parameter_name])
            if 'type' not in field:
                raise ValueError('No type found in parameter %s' % field)
            schema['properties'][parameter_name] = field
        if body_type is not None:
            body_field_name = self.__GetRequestField(
                method_description, body_type)
            if body_field_name in schema['properties']:
                raise ValueError('Failed to normalize request resource name')
            if 'description' not in body_type:
                body_type['description'] = (
                    'A %s resource to be passed as the request body.' % (
                        self.__GetRequestType(body_type),))
            schema['properties'][body_field_name] = body_type
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __CreateVoidResponseType(self, method_description):
        """Create an empty response type and return its id."""
        schema = {}
        method_name = self.__names.ClassName(
            method_description['id'], separator='.')
        schema['id'] = self.__names.ClassName('%sResponse' % method_name)
        schema['type'] = 'object'
        schema['description'] = 'An empty %s response.' % method_name
        self.__message_registry.AddDescriptorFromSchema(schema['id'], schema)
        return schema['id']

    def __NeedRequestType(self, method_description, request_type):
        """Determine if this method needs a new request type created.

        The existing body type can be reused only when the method id is not
        unelidable and every parameter is a path parameter that already
        exists (after name cleaning) as a field of that body type.
        """
        if not request_type:
            return True
        method_id = method_description.get('id', '')
        if method_id in self.__unelidable_request_methods:
            return True
        message = self.__message_registry.LookupDescriptorOrDie(request_type)
        if message is None:
            return True
        field_names = set(x.name for x in message.fields)
        parameters = method_description.get('parameters', {})
        return not all(
            param_info.get('location') == 'path' and
            self.__names.CleanName(param_name) in field_names
            for param_name, param_info in parameters.items())

    def __MaxSizeToInt(self, max_size):
        """Convert a discovery ``maxSize`` string to a byte count.

        Accepts a bare number of bytes or a number followed by ``B``, ``KB``,
        ``MB``, ``GB`` or ``TB`` (case-insensitive), e.g. ``'5MB'``.

        Raises:
          ValueError: if the value cannot be parsed or the unit is unknown.
        """
        # '.?B' (rather than '.B') lets a plain 'B' suffix parse as bytes
        # instead of stealing the last digit ('100B' -> 10 with unit '0B').
        size_groups = re.match(r'(?P<size>\d+)(?P<unit>.?B)?$',
                               str(max_size), re.IGNORECASE)
        if size_groups is None:
            raise ValueError('Could not parse maxSize')
        size, unit = size_groups.group('size', 'unit')
        shift = 0
        if unit is not None:
            unit_dict = {'B': 0, 'KB': 10, 'MB': 20, 'GB': 30, 'TB': 40}
            shift = unit_dict.get(unit.upper())
            if shift is None:
                raise ValueError('Unknown unit %s' % unit)
        return int(size) * (1 << shift)

    def __ComputeUploadConfig(self, media_upload_config, method_id):
        """Fill out the upload config for this method.

        Args:
          media_upload_config: the method's ``mediaUpload`` dict; may be
              missing (None), in which case defaults are used.
          method_id: method id, used only in warning messages.
        """
        config = base_api.ApiUploadInfo()
        media_upload_config = media_upload_config or {}
        if 'maxSize' in media_upload_config:
            config.max_size = self.__MaxSizeToInt(
                media_upload_config['maxSize'])
        if 'accept' not in media_upload_config:
            logging.warning(
                'No accept types found for upload configuration in '
                'method %s, using */*', method_id)
        # The default must be a list: iterating the bare string '*/*' would
        # yield the bogus accept types '*', '/', '*'.
        config.accept.extend([
            str(a) for a in media_upload_config.get('accept', ['*/*'])])

        for accept_pattern in config.accept:
            if not _MIME_PATTERN_RE.match(accept_pattern):
                logging.warning('Unexpected MIME type: %s', accept_pattern)
        protocols = media_upload_config.get('protocols', {})
        for protocol in ('simple', 'resumable'):
            media = protocols.get(protocol, {})
            for attr in ('multipart', 'path'):
                if attr in media:
                    setattr(config, '%s_%s' % (protocol, attr), media[attr])
        return config

    def __ComputeMethodInfo(self, method_description, request, response,
                            request_field):
        """Compute the base_api.ApiMethodInfo for this method.

        Also records the method's scopes in this registry's scope set.

        Raises:
          ValueError: for a parameter whose location is neither ``query``
              nor ``path``.
        """
        relative_path = self.__names.NormalizeRelativePath(
            ''.join((self.__client_info.base_path,
                     method_description['path'])))
        method_id = method_description['id']
        ordered_params = []
        for param_name in method_description.get('parameterOrder', []):
            param_info = method_description['parameters'][param_name]
            if param_info.get('required', False):
                ordered_params.append(param_name)
        method_info = base_api.ApiMethodInfo(
            relative_path=relative_path,
            method_id=method_id,
            http_method=method_description['httpMethod'],
            description=util.CleanDescription(
                method_description.get('description', '')),
            query_params=[],
            path_params=[],
            ordered_params=ordered_params,
            request_type_name=self.__names.ClassName(request),
            response_type_name=self.__names.ClassName(response),
            request_field=request_field,
        )
        flat_path = method_description.get('flatPath', None)
        if flat_path is not None:
            flat_path = self.__names.NormalizeRelativePath(
                self.__client_info.base_path + flat_path)
            if flat_path != relative_path:
                method_info.flat_path = flat_path
        if method_description.get('supportsMediaUpload', False):
            method_info.upload_config = self.__ComputeUploadConfig(
                method_description.get('mediaUpload'), method_id)
        method_info.supports_download = method_description.get(
            'supportsMediaDownload', False)
        self.__all_scopes.update(method_description.get('scopes', ()))
        for param, desc in method_description.get('parameters', {}).items():
            param = self.__names.CleanName(param)
            location = desc['location']
            if location == 'query':
                method_info.query_params.append(param)
            elif location == 'path':
                method_info.path_params.append(param)
            else:
                raise ValueError(
                    'Unknown parameter location %s for parameter %s' % (
                        location, param))
        method_info.path_params.sort()
        method_info.query_params.sort()
        return method_info

    def __BodyFieldName(self, body_type):
        """Return the field name derived from the body's ``$ref`` ('' if none)."""
        if body_type is None:
            return ''
        return self.__names.FieldName(body_type['$ref'])

    def __GetRequestType(self, body_type):
        """Return the class name of the request body type."""
        return self.__names.ClassName(body_type.get('$ref'))

    def __GetRequestField(self, method_description, body_type):
        """Determine the request field for this method.

        Starts from the body type's field name and appends ``_resource`` and
        then ``_body`` suffixes until it no longer collides with a parameter.
        """
        parameters = method_description.get('parameters', {})
        body_field_name = self.__BodyFieldName(body_type)
        if body_field_name in parameters:
            body_field_name = self.__names.FieldName(
                '%s_resource' % body_field_name)
        # It's exceedingly unlikely that we'd get two name collisions, which
        # means it's bound to happen at some point.
        while body_field_name in parameters:
            body_field_name = self.__names.FieldName(
                '%s_body' % body_field_name)
        return body_field_name

    def AddServiceFromResource(self, service_name, methods):
        """Add a new service named service_name with the given methods.

        Nested ``resources`` are registered recursively as services named
        ``<service_name>_<subservice_name>`` before this service itself.
        """
        service_name = self.__names.CleanName(service_name)
        method_descriptions = methods.get('methods', {})
        method_info_map = collections.OrderedDict()
        items = sorted(method_descriptions.items())
        for method_name, method_description in items:
            method_name = self.__names.MethodName(method_name)

            # NOTE: According to the discovery document, if the request or
            # response is present, it will simply contain a `$ref`.
            body_type = method_description.get('request')
            if body_type is None:
                request_type = None
            else:
                request_type = self.__GetRequestType(body_type)
            if self.__NeedRequestType(method_description, request_type):
                request = self.__CreateRequestType(
                    method_description, body_type=body_type)
                request_field = self.__GetRequestField(
                    method_description, body_type)
            else:
                request = request_type
                request_field = base_api.REQUEST_IS_BODY

            if 'response' in method_description:
                response = method_description['response']['$ref']
            else:
                response = self.__CreateVoidResponseType(method_description)

            method_info_map[method_name] = self.__ComputeMethodInfo(
                method_description, request, response, request_field)

        nested_services = methods.get('resources', {})
        services = sorted(nested_services.items())
        for subservice_name, submethods in services:
            new_service_name = '%s_%s' % (service_name, subservice_name)
            self.AddServiceFromResource(new_service_name, submethods)

        self.__RegisterService(service_name, method_info_map)
487