• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2015 Google Inc.
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Command registry for apitools."""
18
19import logging
20import textwrap
21
22from apitools.base.protorpclite import descriptor
23from apitools.base.protorpclite import messages
24from apitools.gen import extended_descriptor
25
26# This is a code generator; we're purposely verbose.
27# pylint:disable=too-many-statements
28
# Flag type used to declare the command-line flag for each protorpc field
# variant (it becomes the suffix of the generated flags.DEFINE_<type> call).
# INT64/UINT64 values are declared as 'string' flags and converted with int()
# by the generated code, and MESSAGE values are read as strings and parsed
# with apitools_base.JsonToMessage (see CommandRegistry.__GetConversion).
_VARIANT_TO_FLAG_TYPE_MAP = {
    # Floating point variants.
    messages.Variant.DOUBLE: 'float',
    messages.Variant.FLOAT: 'float',
    # Integer variants declared directly as integer flags.
    # NOTE(review): SINT64 is 'integer' while INT64/UINT64 are 'string';
    # confirm this asymmetry is intended.
    messages.Variant.INT32: 'integer',
    messages.Variant.UINT32: 'integer',
    messages.Variant.SINT32: 'integer',
    messages.Variant.SINT64: 'integer',
    # Variants whose flag value is taken as a string.
    messages.Variant.INT64: 'string',
    messages.Variant.UINT64: 'string',
    messages.Variant.STRING: 'string',
    messages.Variant.MESSAGE: 'string',
    messages.Variant.BYTES: 'string',
    # Remaining variants.
    messages.Variant.BOOL: 'boolean',
    messages.Variant.ENUM: 'enum',
}
44
45
class FlagInfo(messages.Message):

    """Information about a flag and conversion to a message.

    Fields:
      name: name of this flag.
      type: type of the flag; used as the suffix of the generated
          flags.DEFINE_<type> call (e.g. 'string', 'integer', 'boolean',
          'float', 'enum').
      description: description of the flag.
      default: default value for this flag, stored as a string and
          emitted with %r into the generated DEFINE call.
      enum_values: if this flag is an enum, the list of possible
          values.
      required: whether or not this flag is required.
      fv: name of the flag_values object where this flag should
          be registered. Empty for global flags, which are then
          declared without a flag_values argument.
      conversion: template for type conversion: a %-format string with
          one %s slot for the expression holding the raw flag value.
          Empty means the flag value is used as-is.
      special: (boolean, default: False) If True, this flag doesn't
          correspond to an attribute on the request.
    """
    name = messages.StringField(1)
    type = messages.StringField(2)
    description = messages.StringField(3)
    default = messages.StringField(4)
    enum_values = messages.StringField(5, repeated=True)
    required = messages.BooleanField(6, default=False)
    fv = messages.StringField(7)
    conversion = messages.StringField(8)
    # Special flags (e.g. upload/download options) are skipped when the
    # generated command copies flag values onto the request.
    special = messages.BooleanField(9, default=False)
73
74
class ArgInfo(messages.Message):

    """Information about a single positional command argument.

    Fields:
      name: argument name; also used as the parameter name of the
          generated RunWithArgs method and as the request field name.
      description: description of this argument.
      conversion: template for type conversion: a %-format string with
          one %s slot for the argument expression. Empty means the
          argument is passed through unchanged.
    """
    name = messages.StringField(1)
    description = messages.StringField(2)
    conversion = messages.StringField(3)
87
88
class CommandInfo(messages.Message):

    """Information about a single command.

    Fields:
      name: name of this command.
      class_name: name of the generated apitools_base_cli.NewCmd subclass
          for this command.
      description: description of this command.
      flags: list of FlagInfo messages for the command-specific flags.
      args: list of ArgInfo messages for the positional args.
      request_type: full name of the request message type for this
          command.
      client_method_path: path from the client object to the method
          this command is wrapping.
      has_upload: (boolean, default: False) whether the wrapped method
          accepts an upload; the generated command then passes an
          apitools_base.Upload built from --upload_filename.
      has_download: (boolean, default: False) whether the wrapped method
          supports download; the generated command then passes an
          apitools_base.Download built from --download_filename.
    """
    name = messages.StringField(1)
    class_name = messages.StringField(2)
    description = messages.StringField(3)
    flags = messages.MessageField(FlagInfo, 4, repeated=True)
    args = messages.MessageField(ArgInfo, 5, repeated=True)
    request_type = messages.StringField(6)
    client_method_path = messages.StringField(7)
    has_upload = messages.BooleanField(8, default=False)
    has_download = messages.BooleanField(9, default=False)
112
113
class CommandRegistry(object):

    """Registry for CLI commands.

    Collects global flags (AddGlobalParameters) and one command per API
    method (AddCommandForMethod), then emits a gflags/appcommands based
    CLI module through a printer object (WriteFile).
    """

    def __init__(self, package, version, client_info, message_registry,
                 root_package, base_files_package, protorpc_package, names):
        self.__package = package
        self.__version = version
        self.__client_info = client_info
        self.__names = names
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__protorpc_package = protorpc_package
        self.__command_list = []
        self.__global_flags = []

    def Validate(self):
        """Validates the underlying message registry."""
        self.__message_registry.Validate()

    def AddGlobalParameters(self, schema):
        """Registers every field of schema as a global flag."""
        for field in schema.fields:
            self.__global_flags.append(self.__FlagInfoFromField(field, schema))

    def AddCommandForMethod(self, service_name, method_name, method_info,
                            request, _):
        """Add the given method as a command.

        Fields listed in method_info.ordered_params become positional
        arguments; every other request field becomes a command flag.
        Upload/download capable methods get extra "special" flags.

        Args:
          service_name: attribute name of the service on the client.
          method_name: name of the method on that service.
          method_info: method description (method_id, description,
              ordered_params, upload_config, supports_download).
          request: request type name, resolved through the message
              registry.
          _: unused.
        """
        command_name = self.__GetCommandName(method_info.method_id)
        calling_path = '%s.%s' % (service_name, method_name)
        request_type = self.__message_registry.LookupDescriptor(request)
        description = method_info.description
        if not description:
            description = 'Call the %s method.' % method_info.method_id
        field_map = {f.name: f for f in request_type.fields}
        args = []
        # Only used for membership tests, so a set is enough.
        arg_names = set()
        for field_name in method_info.ordered_params:
            extended_field = field_map[field_name]
            name = extended_field.name
            args.append(ArgInfo(
                name=name,
                description=extended_field.description,
                conversion=self.__GetConversion(extended_field, request_type),
            ))
            arg_names.add(name)
        flags = []
        for extended_field in sorted(request_type.fields,
                                     key=lambda x: x.name):
            field = extended_field.field_descriptor
            if extended_field.name in arg_names:
                continue
            if self.__FieldIsRequired(field):
                logging.warning(
                    'Required field %s not in ordered_params for command %s',
                    extended_field.name, command_name)
            flags.append(self.__FlagInfoFromField(
                extended_field, request_type, fv='fv'))
        if method_info.upload_config:
            # TODO(craigcitro): Consider adding additional flags to allow
            # determining the filename from the object metadata.
            upload_flag_info = FlagInfo(
                name='upload_filename', type='string', default='',
                description='Filename to use for upload.', fv='fv',
                special=True)
            flags.append(upload_flag_info)
            mime_description = (
                'MIME type to use for the upload. Only needed if '
                'the extension on --upload_filename does not determine '
                'the correct (or any) MIME type.')
            mime_type_flag_info = FlagInfo(
                name='upload_mime_type', type='string', default='',
                description=mime_description, fv='fv', special=True)
            flags.append(mime_type_flag_info)
        if method_info.supports_download:
            download_flag_info = FlagInfo(
                name='download_filename', type='string', default='',
                description='Filename to use for download.', fv='fv',
                special=True)
            flags.append(download_flag_info)
            overwrite_description = (
                'If True, overwrite the existing file when downloading.')
            overwrite_flag_info = FlagInfo(
                name='overwrite', type='boolean', default='False',
                description=overwrite_description, fv='fv', special=True)
            flags.append(overwrite_flag_info)
        command_info = CommandInfo(
            name=command_name,
            class_name=self.__names.ClassName(command_name),
            description=description,
            flags=flags,
            args=args,
            request_type=request_type.full_name,
            client_method_path=calling_path,
            has_upload=bool(method_info.upload_config),
            has_download=bool(method_info.supports_download)
        )
        self.__command_list.append(command_info)

    def __LookupMessage(self, message, field):
        """Looks up the type named by field.type_name.

        First tries the name nested inside message, then the bare name.
        Returns the registry's result for the last lookup tried; callers
        treat None as "not found".
        """
        message_type = self.__message_registry.LookupDescriptor(
            '%s.%s' % (message.name, field.type_name))
        if message_type is None:
            message_type = self.__message_registry.LookupDescriptor(
                field.type_name)
        return message_type

    def __GetCommandName(self, method_id):
        """Strips the '<package>.' prefix from method_id and joins with '_'."""
        command_name = method_id
        prefix = '%s.' % self.__package
        if command_name.startswith(prefix):
            command_name = command_name[len(prefix):]
        command_name = command_name.replace('.', '_')
        return command_name

    def __GetConversion(self, extended_field, extended_message):
        """Returns a template for field type.

        The template is a %-format string with a single %s slot for the
        expression holding the raw flag/argument value; '' means no
        conversion is needed. Repeated fields get a list comprehension
        applying the element conversion to each item.

        Raises:
          ValueError: if a message/enum field's type cannot be found.
        """
        field = extended_field.field_descriptor

        type_name = ''
        if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM):
            if field.type_name.startswith('apitools.base.protorpclite.'):
                type_name = field.type_name
            else:
                field_message = self.__LookupMessage(extended_message, field)
                if field_message is None:
                    raise ValueError(
                        'Could not find type for field %s' % field.name)
                type_name = 'messages.%s' % field_message.full_name

        template = ''
        if field.variant in (messages.Variant.INT64, messages.Variant.UINT64):
            template = 'int(%s)'
        elif field.variant == messages.Variant.MESSAGE:
            template = 'apitools_base.JsonToMessage(%s, %%s)' % type_name
        elif field.variant == messages.Variant.ENUM:
            template = '%s(%%s)' % type_name
        elif field.variant == messages.Variant.STRING:
            template = "%s.decode('utf8')"

        if template and self.__FieldIsRepeated(field):
            template = '[%s for x in %%s]' % (template % 'x')

        return template

    def __FieldIsRequired(self, field):
        """True if the field descriptor's label is REQUIRED."""
        return field.label == descriptor.FieldDescriptor.Label.REQUIRED

    def __FieldIsRepeated(self, field):
        """True if the field descriptor's label is REPEATED."""
        return field.label == descriptor.FieldDescriptor.Label.REPEATED

    def __FlagInfoFromField(self, extended_field, extended_message, fv=''):
        """Creates FlagInfo object for given field.

        Args:
          extended_field: field whose field_descriptor describes the flag.
          extended_message: message containing the field, used to resolve
              enum and message type names.
          fv: name of the flag_values object in generated code; '' for
              global flags.

        Raises:
          ValueError: if an enum field's type cannot be found.
        """
        field = extended_field.field_descriptor
        flag_info = FlagInfo()
        flag_info.name = str(field.name)
        # TODO(craigcitro): We should key by variant.
        flag_info.type = _VARIANT_TO_FLAG_TYPE_MAP[field.variant]
        flag_info.description = extended_field.description
        if field.default_value:
            # TODO(craigcitro): Formatting?
            flag_info.default = field.default_value
        if flag_info.type == 'enum':
            # TODO(craigcitro): Does protorpc do this for us?
            enum_type = self.__LookupMessage(extended_message, field)
            if enum_type is None:
                # Interpolate the name; passing it as a second argument
                # would leave a literal '%s' in the message.
                raise ValueError(
                    'Cannot find enum type %s' % field.type_name)
            flag_info.enum_values = [x.name for x in enum_type.values]
            # Note that this choice is completely arbitrary -- but we only
            # push the value through if the user specifies it, so this
            # doesn't hurt anything.
            if flag_info.default is None:
                flag_info.default = flag_info.enum_values[0]
        if self.__FieldIsRequired(field):
            flag_info.required = True
        flag_info.fv = fv
        flag_info.conversion = self.__GetConversion(
            extended_field, extended_message)
        return flag_info

    def __PrintFlagDeclarations(self, printer):
        """Writes out command line flag declarations."""
        package = self.__client_info.package
        function_name = '_Declare%sFlags' % (package[0].upper() + package[1:])
        printer()
        printer()
        printer('def %s():', function_name)
        with printer.Indent():
            printer('"""Declare global flags in an idempotent way."""')
            printer("if 'api_endpoint' in flags.FLAGS:")
            with printer.Indent():
                printer('return')
            printer('flags.DEFINE_string(')
            with printer.Indent('    '):
                printer("'api_endpoint',")
                printer('%r,', self.__client_info.base_url)
                printer("'URL of the API endpoint to use.',")
                printer("short_name='%s_url')", self.__package)
            printer('flags.DEFINE_string(')
            with printer.Indent('    '):
                printer("'history_file',")
                printer('%r,', '~/.%s.%s.history' %
                        (self.__package, self.__version))
                printer("'File with interactive shell history.')")
            printer('flags.DEFINE_multistring(')
            with printer.Indent('    '):
                printer("'add_header', [],")
                printer("'Additional http headers (as key=value strings). '")
                printer("'Can be specified multiple times.')")
            printer('flags.DEFINE_string(')
            with printer.Indent('    '):
                printer("'service_account_json_keyfile', '',")
                printer("'Filename for a JSON service account key downloaded'")
                printer("' from the Developer Console.')")
            for flag_info in self.__global_flags:
                self.__PrintFlag(printer, flag_info)
        printer()
        printer()
        printer('FLAGS = flags.FLAGS')
        printer('apitools_base_cli.DeclareBaseFlags()')
        printer('%s()', function_name)

    def __PrintGetGlobalParams(self, printer):
        """Writes out GetGlobalParamsFromFlags function."""
        printer('def GetGlobalParamsFromFlags():')
        with printer.Indent():
            printer('"""Return a StandardQueryParameters based on flags."""')
            printer('result = messages.StandardQueryParameters()')

            for flag_info in self.__global_flags:
                rhs = 'FLAGS.%s' % flag_info.name
                if flag_info.conversion:
                    rhs = flag_info.conversion % rhs
                printer('if FLAGS[%r].present:', flag_info.name)
                with printer.Indent():
                    printer('result.%s = %s', flag_info.name, rhs)
            printer('return result')
        printer()
        printer()

    def __PrintGetClient(self, printer):
        """Writes out GetClientFromFlags function."""
        printer('def GetClientFromFlags():')
        with printer.Indent():
            printer('"""Return a client object, configured from flags."""')
            printer('log_request = FLAGS.log_request or '
                    'FLAGS.log_request_response')
            printer('log_response = FLAGS.log_response or '
                    'FLAGS.log_request_response')
            printer('api_endpoint = apitools_base.NormalizeApiEndpoint('
                    'FLAGS.api_endpoint)')
            printer("additional_http_headers = dict(x.split('=', 1) for x in "
                    "FLAGS.add_header)")
            printer('credentials_args = {')
            with printer.Indent('    '):
                printer("'service_account_json_keyfile': os.path.expanduser("
                        'FLAGS.service_account_json_keyfile)')
            printer('}')
            printer('try:')
            with printer.Indent():
                printer('client = client_lib.%s(',
                        self.__client_info.client_class_name)
                with printer.Indent(indent='    '):
                    printer('api_endpoint, log_request=log_request,')
                    printer('log_response=log_response,')
                    printer('credentials_args=credentials_args,')
                    printer('additional_http_headers=additional_http_headers)')
            printer('except apitools_base.CredentialsError as e:')
            with printer.Indent():
                # NOTE(review): the generated module uses Python 2 print
                # statements; '%%' becomes '%' once the printer formats it.
                printer("print 'Error creating credentials: %%s' %% e")
                printer('sys.exit(1)')
            printer('return client')
        printer()
        printer()

    def __PrintCommandDocstring(self, printer, command_info):
        """Writes the docstring of a generated RunWithArgs method.

        The description is wrapped to the printer's width, followed by
        the Args and Flags sections.
        """
        with printer.CommentContext():
            for line in textwrap.wrap('"""%s' % command_info.description,
                                      printer.CalculateWidth()):
                printer(line)
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.args, 'Args')
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.flags, 'Flags')
            printer('"""')

    def __PrintFlag(self, printer, flag_info):
        """Writes out given flag definition.

        Emits a flags.DEFINE_<type>(...) call, with flag_values=<fv> when
        flag_info.fv is set, followed by flags.MarkFlagAsRequired when the
        flag is required. An unset description is emitted as ''.
        """
        printer('flags.DEFINE_%s(', flag_info.type)
        with printer.Indent(indent='    '):
            printer('%r,', flag_info.name)
            printer('%r,', flag_info.default)
            if flag_info.type == 'enum':
                printer('%r,', flag_info.enum_values)

            # textwrap.wrap cannot handle None, and description is an
            # optional message field.
            description = flag_info.description or ''
            # TODO(craigcitro): Consider using 'drop_whitespace' elsewhere.
            description_lines = textwrap.wrap(
                description, 75 - len(printer.indent),
                drop_whitespace=False)
            for line in description_lines[:-1]:
                printer('%r', line)
            last_line = description_lines[-1] if description_lines else ''
            printer('%r%s', last_line, ',' if flag_info.fv else ')')
            if flag_info.fv:
                printer('flag_values=%s)', flag_info.fv)
        if flag_info.required:
            printer('flags.MarkFlagAsRequired(%r)', flag_info.name)

    def __PrintPyShell(self, printer):
        """Writes out PyShell class."""
        printer('class PyShell(appcommands.Cmd):')
        printer()
        with printer.Indent():
            printer('def Run(self, _):')
            with printer.Indent():
                printer(
                    '"""Run an interactive python shell with the client."""')
                printer('client = GetClientFromFlags()')
                printer('params = GetGlobalParamsFromFlags()')
                printer('for field in params.all_fields():')
                with printer.Indent():
                    printer('value = params.get_assigned_value(field.name)')
                    printer('if value != field.default:')
                    with printer.Indent():
                        printer('client.AddGlobalParam(field.name, value)')
                printer('banner = """')
                printer('       == %s interactive console ==' % (
                    self.__client_info.package))
                printer('             client: a %s client' %
                        self.__client_info.package)
                printer('      apitools_base: base apitools module')
                printer('     messages: the generated messages module')
                printer('"""')
                printer('local_vars = {')
                with printer.Indent(indent='    '):
                    printer("'apitools_base': apitools_base,")
                    printer("'client': client,")
                    printer("'client_lib': client_lib,")
                    printer("'messages': messages,")
                printer('}')
                printer("if platform.system() == 'Linux':")
                with printer.Indent():
                    printer('console = apitools_base_cli.ConsoleWithReadline(')
                    with printer.Indent(indent='    '):
                        printer('local_vars, histfile=FLAGS.history_file)')
                printer('else:')
                with printer.Indent():
                    printer('console = code.InteractiveConsole(local_vars)')
                printer('try:')
                with printer.Indent():
                    printer('console.interact(banner)')
                printer('except SystemExit as e:')
                with printer.Indent():
                    printer('return e.code')
        printer()
        printer()

    def WriteFile(self, printer):
        """Writes the complete generated CLI module using printer.

        The module contains the imports, global flag declarations,
        GetGlobalParamsFromFlags, GetClientFromFlags, a PyShell command,
        one command class per registered method, and a main() that
        registers all commands with appcommands.
        """
        printer('#!/usr/bin/env python')
        printer('"""CLI for %s, version %s."""',
                self.__package, self.__version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        # TODO(craigcitro): Add a build stamp, along with some other
        # information.
        printer()
        printer('import code')
        printer('import os')
        printer('import platform')
        printer('import sys')
        printer()
        printer('from %s import message_types', self.__protorpc_package)
        printer('from %s import messages', self.__protorpc_package)
        printer()
        printer('from google.apputils import appcommands')
        printer('import gflags as flags')
        printer()
        printer('import %s as apitools_base', self.__base_files_package)
        printer('from %s import cli as apitools_base_cli',
                self.__base_files_package)
        printer('import %s as client_lib',
                self.__client_info.client_rule_name)
        printer('import %s as messages',
                self.__client_info.messages_rule_name)
        self.__PrintFlagDeclarations(printer)
        printer()
        printer()
        self.__PrintGetGlobalParams(printer)
        self.__PrintGetClient(printer)
        self.__PrintPyShell(printer)
        self.__PrintCommands(printer)
        printer('def main(_):')
        with printer.Indent():
            printer("appcommands.AddCmd('pyshell', PyShell)")
            for command_info in self.__command_list:
                printer("appcommands.AddCmd('%s', %s)",
                        command_info.name, command_info.class_name)
            printer()
            printer('apitools_base_cli.SetupLogger()')
            # TODO(craigcitro): Just call SetDefaultCommand as soon as
            # another appcommands release happens and this exists
            # externally.
            printer("if hasattr(appcommands, 'SetDefaultCommand'):")
            with printer.Indent():
                printer("appcommands.SetDefaultCommand('pyshell')")
        printer()
        printer()
        printer('run_main = apitools_base_cli.run_main')
        printer()
        printer("if __name__ == '__main__':")
        with printer.Indent():
            printer('appcommands.Run()')

    def __PrintCommands(self, printer):
        """Print all commands in this registry using printer.

        Each command becomes an apitools_base_cli.NewCmd subclass whose
        __init__ declares its flags and whose RunWithArgs builds the
        request, applies flags the user set, and calls the client method.
        """
        for command_info in self.__command_list:
            arg_list = [arg_info.name for arg_info in command_info.args]
            printer(
                'class %s(apitools_base_cli.NewCmd):', command_info.class_name)
            with printer.Indent():
                printer('"""Command wrapping %s."""',
                        command_info.client_method_path)
                printer()
                printer('usage = """%s%s%s"""',
                        command_info.name,
                        ' ' if arg_list else '',
                        ' '.join('<%s>' % argname for argname in arg_list))
                printer()
                printer('def __init__(self, name, fv):')
                with printer.Indent():
                    printer('super(%s, self).__init__(name, fv)',
                            command_info.class_name)
                    for flag in command_info.flags:
                        self.__PrintFlag(printer, flag)
                printer()
                printer('def RunWithArgs(%s):', ', '.join(['self'] + arg_list))
                with printer.Indent():
                    self.__PrintCommandDocstring(printer, command_info)
                    printer('client = GetClientFromFlags()')
                    printer('global_params = GetGlobalParamsFromFlags()')
                    printer(
                        'request = messages.%s(', command_info.request_type)
                    with printer.Indent(indent='    '):
                        for arg in command_info.args:
                            rhs = arg.name
                            if arg.conversion:
                                rhs = arg.conversion % arg.name
                            printer('%s=%s,', arg.name, rhs)
                        printer(')')
                    for flag_info in command_info.flags:
                        # Special flags drive upload/download handling
                        # below rather than mapping to request fields.
                        if flag_info.special:
                            continue
                        rhs = 'FLAGS.%s' % flag_info.name
                        if flag_info.conversion:
                            rhs = flag_info.conversion % rhs
                        printer('if FLAGS[%r].present:', flag_info.name)
                        with printer.Indent():
                            printer('request.%s = %s', flag_info.name, rhs)
                    call_args = ['request', 'global_params=global_params']
                    if command_info.has_upload:
                        call_args.append('upload=upload')
                        printer('upload = None')
                        printer('if FLAGS.upload_filename:')
                        with printer.Indent():
                            printer('upload = apitools_base.Upload.FromFile(')
                            printer('    FLAGS.upload_filename, '
                                    'FLAGS.upload_mime_type,')
                            printer('    progress_callback='
                                    'apitools_base.UploadProgressPrinter,')
                            printer('    finish_callback='
                                    'apitools_base.UploadCompletePrinter)')
                    if command_info.has_download:
                        call_args.append('download=download')
                        printer('download = None')
                        printer('if FLAGS.download_filename:')
                        with printer.Indent():
                            printer('download = apitools_base.Download.'
                                    'FromFile(FLAGS.download_filename, '
                                    'overwrite=FLAGS.overwrite,')
                            printer('    progress_callback='
                                    'apitools_base.DownloadProgressPrinter,')
                            printer('    finish_callback='
                                    'apitools_base.DownloadCompletePrinter)')
                    printer(
                        'result = client.%s(', command_info.client_method_path)
                    with printer.Indent(indent='    '):
                        printer('%s)', ', '.join(call_args))
                    printer('print apitools_base_cli.FormatOutput(result)')
            printer()
            printer()
609