#!/usr/bin/env python
#
# Copyright 2015 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Command registry for apitools.

Collects one command per API method (plus flags for the API's global
parameters) and writes out the source of a gflags/appcommands based
command-line tool for the generated client (see CommandRegistry.WriteFile).
"""

import logging
import textwrap

from apitools.base.protorpclite import descriptor
from apitools.base.protorpclite import messages
from apitools.gen import extended_descriptor

# This is a code generator; we're purposely verbose.
# pylint:disable=too-many-statements

# Maps a field variant to the suffix of the gflags `flags.DEFINE_<type>`
# call used to declare its flag. INT64/UINT64 are declared as string flags
# and converted with int() in the generated code (see __GetConversion);
# MESSAGE values are passed as JSON strings.
# NOTE(review): SINT64 maps to 'integer' while INT64/UINT64 use 'string'
# plus int(); confirm whether this asymmetry is intended.
_VARIANT_TO_FLAG_TYPE_MAP = {
    messages.Variant.DOUBLE: 'float',
    messages.Variant.FLOAT: 'float',
    messages.Variant.INT64: 'string',
    messages.Variant.UINT64: 'string',
    messages.Variant.INT32: 'integer',
    messages.Variant.BOOL: 'boolean',
    messages.Variant.STRING: 'string',
    messages.Variant.MESSAGE: 'string',
    messages.Variant.BYTES: 'string',
    messages.Variant.UINT32: 'integer',
    messages.Variant.ENUM: 'enum',
    messages.Variant.SINT32: 'integer',
    messages.Variant.SINT64: 'integer',
}


class FlagInfo(messages.Message):

    """Information about a flag and conversion to a message.

    Fields:
      name: name of this flag.
      type: type of the flag; the suffix used in the generated
          `flags.DEFINE_<type>` call (e.g. 'string', 'integer', 'enum').
      description: description of the flag.
      default: default value for this flag.
      enum_values: if this flag is an enum, the list of possible
          values.
      required: whether or not this flag is required.
      fv: name of the flag_values object where this flag should
          be registered.
      conversion: %-style template (one %s for the raw flag value) for
          type conversion; empty if the value is used as-is.
      special: (boolean, default: False) If True, this flag doesn't
          correspond to an attribute on the request.
    """
    name = messages.StringField(1)
    type = messages.StringField(2)
    description = messages.StringField(3)
    default = messages.StringField(4)
    enum_values = messages.StringField(5, repeated=True)
    required = messages.BooleanField(6, default=False)
    fv = messages.StringField(7)
    conversion = messages.StringField(8)
    special = messages.BooleanField(9, default=False)


class ArgInfo(messages.Message):

    """Information about a single positional command argument.

    Fields:
      name: argument name.
      description: description of this argument.
      conversion: %-style template (one %s for the raw argument value) for
          type conversion; empty if the value is used as-is.
    """
    name = messages.StringField(1)
    description = messages.StringField(2)
    conversion = messages.StringField(3)


class CommandInfo(messages.Message):

    """Information about a single command.

    Fields:
      name: name of this command.
      class_name: name of the apitools_base.NewCmd class for this command.
      description: description of this command.
      flags: list of FlagInfo messages for the command-specific flags.
      args: list of ArgInfo messages for the positional args.
      request_type: name of the request type for this command.
      client_method_path: path from the client object to the method
          this command is wrapping.
      has_upload: whether the wrapped method accepts an upload; if so the
          command gets --upload_filename and --upload_mime_type flags.
      has_download: whether the wrapped method supports download; if so the
          command gets --download_filename and --overwrite flags.
    """
    name = messages.StringField(1)
    class_name = messages.StringField(2)
    description = messages.StringField(3)
    flags = messages.MessageField(FlagInfo, 4, repeated=True)
    args = messages.MessageField(ArgInfo, 5, repeated=True)
    request_type = messages.StringField(6)
    client_method_path = messages.StringField(7)
    has_upload = messages.BooleanField(8, default=False)
    has_download = messages.BooleanField(9, default=False)


class CommandRegistry(object):

    """Registry for CLI commands.

    Commands are added with AddCommandForMethod (one per API method) and
    global flags with AddGlobalParameters; WriteFile then emits the source of
    the command-line tool through the supplied printer.
    """

    def __init__(self, package, version, client_info, message_registry,
                 root_package, base_files_package, protorpc_package, names):
        self.__package = package
        self.__version = version
        self.__client_info = client_info
        self.__names = names
        self.__message_registry = message_registry
        self.__root_package = root_package
        self.__base_files_package = base_files_package
        self.__protorpc_package = protorpc_package
        self.__command_list = []
        self.__global_flags = []

    def Validate(self):
        """Delegates validation to the underlying message registry."""
        self.__message_registry.Validate()

    def AddGlobalParameters(self, schema):
        """Registers one global flag for every field of `schema`."""
        for field in schema.fields:
            self.__global_flags.append(self.__FlagInfoFromField(field, schema))

    def AddCommandForMethod(self, service_name, method_name, method_info,
                            request, _):
        """Add the given method as a command.

        Fields named in method_info.ordered_params become positional
        arguments (in that order); every other request field becomes a
        command-local flag. Upload/download methods get extra "special"
        flags that are not copied onto the request.

        Args:
          service_name: first component of the path from the generated
              client to the method (client.<service_name>.<method_name>).
          method_name: second component of that path.
          method_info: method description; method_id, description,
              ordered_params, upload_config and supports_download are used.
          request: name of the request message, resolved through the
              message registry.
          _: unused.
        """
        command_name = self.__GetCommandName(method_info.method_id)
        calling_path = '%s.%s' % (service_name, method_name)
        request_type = self.__message_registry.LookupDescriptor(request)
        description = method_info.description
        if not description:
            description = 'Call the %s method.' % method_info.method_id
        field_map = {f.name: f for f in request_type.fields}

        # Positional arguments, in the method's declared parameter order.
        args = []
        arg_names = set()
        for field_name in method_info.ordered_params:
            extended_field = field_map[field_name]
            name = extended_field.name
            args.append(ArgInfo(
                name=name,
                description=extended_field.description,
                conversion=self.__GetConversion(extended_field, request_type),
            ))
            arg_names.add(name)

        # Every remaining request field becomes a flag, sorted by name.
        flags = []
        for extended_field in sorted(request_type.fields,
                                     key=lambda x: x.name):
            field = extended_field.field_descriptor
            if extended_field.name in arg_names:
                continue
            if self.__FieldIsRequired(field):
                logging.warning(
                    'Required field %s not in ordered_params for command %s',
                    extended_field.name, command_name)
            flags.append(self.__FlagInfoFromField(
                extended_field, request_type, fv='fv'))

        if method_info.upload_config:
            # TODO(craigcitro): Consider adding additional flags to allow
            # determining the filename from the object metadata.
            upload_flag_info = FlagInfo(
                name='upload_filename', type='string', default='',
                description='Filename to use for upload.', fv='fv',
                special=True)
            flags.append(upload_flag_info)
            mime_description = (
                'MIME type to use for the upload. Only needed if '
                'the extension on --upload_filename does not determine '
                'the correct (or any) MIME type.')
            mime_type_flag_info = FlagInfo(
                name='upload_mime_type', type='string', default='',
                description=mime_description, fv='fv', special=True)
            flags.append(mime_type_flag_info)
        if method_info.supports_download:
            download_flag_info = FlagInfo(
                name='download_filename', type='string', default='',
                description='Filename to use for download.', fv='fv',
                special=True)
            flags.append(download_flag_info)
            overwrite_description = (
                'If True, overwrite the existing file when downloading.')
            overwrite_flag_info = FlagInfo(
                name='overwrite', type='boolean', default='False',
                description=overwrite_description, fv='fv', special=True)
            flags.append(overwrite_flag_info)

        command_info = CommandInfo(
            name=command_name,
            class_name=self.__names.ClassName(command_name),
            description=description,
            flags=flags,
            args=args,
            request_type=request_type.full_name,
            client_method_path=calling_path,
            has_upload=bool(method_info.upload_config),
            has_download=bool(method_info.supports_download)
        )
        self.__command_list.append(command_info)

    def __LookupMessage(self, message, field):
        """Returns the descriptor for field's type, or None if not found.

        The name '<message.name>.<field.type_name>' is tried first, then the
        bare field.type_name.
        """
        message_type = self.__message_registry.LookupDescriptor(
            '%s.%s' % (message.name, field.type_name))
        if message_type is None:
            message_type = self.__message_registry.LookupDescriptor(
                field.type_name)
        return message_type

    def __GetCommandName(self, method_id):
        """Derives a command name from a method id.

        A leading '<package>.' is stripped and remaining dots become
        underscores.
        """
        command_name = method_id
        prefix = '%s.' % self.__package
        if command_name.startswith(prefix):
            command_name = command_name[len(prefix):]
        command_name = command_name.replace('.', '_')
        return command_name

    def __GetConversion(self, extended_field, extended_message):
        """Returns a template for field type.

        The template is a %-style string with a single %s placeholder for
        the raw command-line value, e.g. 'int(%s)' for 64-bit integers.
        Repeated fields wrap the per-item conversion in a list
        comprehension. An empty string means no conversion is needed.

        Raises:
          ValueError: if a message/enum field's type cannot be found.
        """
        field = extended_field.field_descriptor

        type_name = ''
        if field.variant in (messages.Variant.MESSAGE, messages.Variant.ENUM):
            if field.type_name.startswith('apitools.base.protorpclite.'):
                type_name = field.type_name
            else:
                field_message = self.__LookupMessage(extended_message, field)
                if field_message is None:
                    raise ValueError(
                        'Could not find type for field %s' % field.name)
                type_name = 'messages.%s' % field_message.full_name

        template = ''
        if field.variant in (messages.Variant.INT64, messages.Variant.UINT64):
            template = 'int(%s)'
        elif field.variant == messages.Variant.MESSAGE:
            template = 'apitools_base.JsonToMessage(%s, %%s)' % type_name
        elif field.variant == messages.Variant.ENUM:
            template = '%s(%%s)' % type_name
        elif field.variant == messages.Variant.STRING:
            template = "%s.decode('utf8')"

        if self.__FieldIsRepeated(extended_field.field_descriptor):
            if template:
                template = '[%s for x in %%s]' % (template % 'x')

        return template

    def __FieldIsRequired(self, field):
        """True if the field descriptor is labeled REQUIRED."""
        return field.label == descriptor.FieldDescriptor.Label.REQUIRED

    def __FieldIsRepeated(self, field):
        """True if the field descriptor is labeled REPEATED."""
        return field.label == descriptor.FieldDescriptor.Label.REPEATED

    def __FlagInfoFromField(self, extended_field, extended_message, fv=''):
        """Creates FlagInfo object for given field.

        Args:
          extended_field: the field to expose as a flag; its field_descriptor
              supplies name, variant, default value and label.
          extended_message: the message containing the field, used to
              resolve message/enum type names.
          fv: name of the flag_values object the generated flag is
              registered on; empty for global flags.

        Returns:
          A FlagInfo describing the flag.

        Raises:
          KeyError: if the field's variant has no flag type mapping.
          ValueError: if the field's enum/message type cannot be found.
        """
        field = extended_field.field_descriptor
        flag_info = FlagInfo()
        flag_info.name = str(field.name)
        # TODO(craigcitro): We should key by variant.
        flag_info.type = _VARIANT_TO_FLAG_TYPE_MAP[field.variant]
        flag_info.description = extended_field.description
        if field.default_value:
            # TODO(craigcitro): Formatting?
            flag_info.default = field.default_value
        if flag_info.type == 'enum':
            # TODO(craigcitro): Does protorpc do this for us?
            enum_type = self.__LookupMessage(extended_message, field)
            if enum_type is None:
                raise ValueError(
                    'Cannot find enum type %s' % field.type_name)
            flag_info.enum_values = [x.name for x in enum_type.values]
            # Note that this choice is completely arbitrary -- but we only
            # push the value through if the user specifies it, so this
            # doesn't hurt anything. Skip it for an enum without values
            # rather than failing with an IndexError.
            if flag_info.default is None and flag_info.enum_values:
                flag_info.default = flag_info.enum_values[0]
        if self.__FieldIsRequired(field):
            flag_info.required = True
        flag_info.fv = fv
        flag_info.conversion = self.__GetConversion(
            extended_field, extended_message)
        return flag_info

    def __PrintFlagDeclarations(self, printer):
        """Writes out command line flag declarations.

        Emits an idempotent _Declare<Package>Flags() function defining the
        api_endpoint, history_file, add_header and
        service_account_json_keyfile flags plus one flag per global
        parameter, followed by the module-level FLAGS alias and the calls
        that declare the flags.
        """
        package = self.__client_info.package
        function_name = '_Declare%sFlags' % (package[0].upper() + package[1:])
        printer()
        printer()
        printer('def %s():', function_name)
        with printer.Indent():
            printer('"""Declare global flags in an idempotent way."""')
            printer("if 'api_endpoint' in flags.FLAGS:")
            with printer.Indent():
                printer('return')
            printer('flags.DEFINE_string(')
            with printer.Indent(' '):
                printer("'api_endpoint',")
                printer('%r,', self.__client_info.base_url)
                printer("'URL of the API endpoint to use.',")
                printer("short_name='%s_url')", self.__package)
            printer('flags.DEFINE_string(')
            with printer.Indent(' '):
                printer("'history_file',")
                printer('%r,', '~/.%s.%s.history' %
                        (self.__package, self.__version))
                printer("'File with interactive shell history.')")
            printer('flags.DEFINE_multistring(')
            with printer.Indent(' '):
                printer("'add_header', [],")
                printer("'Additional http headers (as key=value strings). '")
                printer("'Can be specified multiple times.')")
            printer('flags.DEFINE_string(')
            with printer.Indent(' '):
                printer("'service_account_json_keyfile', '',")
                printer("'Filename for a JSON service account key downloaded'")
                printer("' from the Developer Console.')")
            for flag_info in self.__global_flags:
                self.__PrintFlag(printer, flag_info)
        printer()
        printer()
        printer('FLAGS = flags.FLAGS')
        printer('apitools_base_cli.DeclareBaseFlags()')
        printer('%s()', function_name)

    def __PrintGetGlobalParams(self, printer):
        """Writes out GetGlobalParamsFromFlags function.

        The generated function copies each global flag that was explicitly
        given on the command line onto a StandardQueryParameters message.
        """
        printer('def GetGlobalParamsFromFlags():')
        with printer.Indent():
            printer('"""Return a StandardQueryParameters based on flags."""')
            printer('result = messages.StandardQueryParameters()')

            for flag_info in self.__global_flags:
                rhs = 'FLAGS.%s' % flag_info.name
                if flag_info.conversion:
                    rhs = flag_info.conversion % rhs
                printer('if FLAGS[%r].present:', flag_info.name)
                with printer.Indent():
                    printer('result.%s = %s', flag_info.name, rhs)
            printer('return result')
        printer()
        printer()

    def __PrintGetClient(self, printer):
        """Writes out GetClientFromFlags function.

        The generated function builds the client from the logging, endpoint,
        header and credential flags, and exits with status 1 if credentials
        cannot be created.
        """
        printer('def GetClientFromFlags():')
        with printer.Indent():
            printer('"""Return a client object, configured from flags."""')
            printer('log_request = FLAGS.log_request or '
                    'FLAGS.log_request_response')
            printer('log_response = FLAGS.log_response or '
                    'FLAGS.log_request_response')
            printer('api_endpoint = apitools_base.NormalizeApiEndpoint('
                    'FLAGS.api_endpoint)')
            printer("additional_http_headers = dict(x.split('=', 1) for x in "
                    "FLAGS.add_header)")
            printer('credentials_args = {')
            with printer.Indent(' '):
                printer("'service_account_json_keyfile': os.path.expanduser("
                        'FLAGS.service_account_json_keyfile)')
            printer('}')
            printer('try:')
            with printer.Indent():
                printer('client = client_lib.%s(',
                        self.__client_info.client_class_name)
                with printer.Indent(indent=' '):
                    printer('api_endpoint, log_request=log_request,')
                    printer('log_response=log_response,')
                    printer('credentials_args=credentials_args,')
                    printer('additional_http_headers=additional_http_headers)')
            printer('except apitools_base.CredentialsError as e:')
            with printer.Indent():
                # '%%' is emitted as '%' because the printer %-formats its
                # first argument.
                printer("print 'Error creating credentials: %%s' %% e")
                printer('sys.exit(1)')
            printer('return client')
        printer()
        printer()

    def __PrintCommandDocstring(self, printer, command_info):
        """Writes the docstring of a generated RunWithArgs method.

        Printing happens inside printer.CommentContext(), so the description
        text is emitted verbatim rather than %-formatted.
        """
        with printer.CommentContext():
            for line in textwrap.wrap('"""%s' % command_info.description,
                                      printer.CalculateWidth()):
                printer(line)
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.args, 'Args')
            extended_descriptor.PrintIndentedDescriptions(
                printer, command_info.flags, 'Flags')
            printer('"""')

    def __PrintFlag(self, printer, flag_info):
        """Writes out given flag definition.

        Emits a flags.DEFINE_<type>(...) call (passing flag_values=<fv> when
        flag_info.fv is set), followed by flags.MarkFlagAsRequired for
        required flags.
        """
        printer('flags.DEFINE_%s(', flag_info.type)
        with printer.Indent(indent=' '):
            printer('%r,', flag_info.name)
            printer('%r,', flag_info.default)
            if flag_info.type == 'enum':
                printer('%r,', flag_info.enum_values)

            # TODO(craigcitro): Consider using 'drop_whitespace' elsewhere.
            # The description is emitted as adjacent string literals; keeping
            # whitespace at the wrap points lets them concatenate back into
            # the original text. A missing (None) description is emitted as
            # an empty string instead of crashing textwrap.
            description_lines = textwrap.wrap(
                flag_info.description or '', 75 - len(printer.indent),
                drop_whitespace=False)
            for line in description_lines[:-1]:
                printer('%r', line)
            last_line = description_lines[-1] if description_lines else ''
            printer('%r%s', last_line, ',' if flag_info.fv else ')')
            if flag_info.fv:
                printer('flag_values=%s)', flag_info.fv)
        if flag_info.required:
            printer('flags.MarkFlagAsRequired(%r)', flag_info.name)

    def __PrintPyShell(self, printer):
        """Writes out PyShell class.

        The generated command starts an interactive console with the client,
        apitools_base, client_lib and messages pre-bound. It uses readline
        history on Linux and code.InteractiveConsole otherwise.
        """
        printer('class PyShell(appcommands.Cmd):')
        printer()
        with printer.Indent():
            printer('def Run(self, _):')
            with printer.Indent():
                printer(
                    '"""Run an interactive python shell with the client."""')
                printer('client = GetClientFromFlags()')
                printer('params = GetGlobalParamsFromFlags()')
                printer('for field in params.all_fields():')
                with printer.Indent():
                    printer('value = params.get_assigned_value(field.name)')
                    printer('if value != field.default:')
                    with printer.Indent():
                        printer('client.AddGlobalParam(field.name, value)')
                printer('banner = """')
                # Let the printer do the formatting (as everywhere else) so
                # the package name is not %-interpolated twice.
                printer(' == %s interactive console ==',
                        self.__client_info.package)
                printer(' client: a %s client',
                        self.__client_info.package)
                printer(' apitools_base: base apitools module')
                printer(' messages: the generated messages module')
                printer('"""')
                printer('local_vars = {')
                with printer.Indent(indent=' '):
                    printer("'apitools_base': apitools_base,")
                    printer("'client': client,")
                    printer("'client_lib': client_lib,")
                    printer("'messages': messages,")
                printer('}')
                printer("if platform.system() == 'Linux':")
                with printer.Indent():
                    printer('console = apitools_base_cli.ConsoleWithReadline(')
                    with printer.Indent(indent=' '):
                        printer('local_vars, histfile=FLAGS.history_file)')
                printer('else:')
                with printer.Indent():
                    printer('console = code.InteractiveConsole(local_vars)')
                printer('try:')
                with printer.Indent():
                    printer('console.interact(banner)')
                printer('except SystemExit as e:')
                with printer.Indent():
                    printer('return e.code')
        printer()
        printer()

    def WriteFile(self, printer):
        """Writes the complete generated CLI module.

        The output contains the imports, the flag declarations,
        GetGlobalParamsFromFlags, GetClientFromFlags, the PyShell command,
        one apitools_base_cli.NewCmd subclass per registered command, and a
        main() that registers all commands with appcommands ('pyshell' is
        made the default command when appcommands supports it).
        """
        printer('#!/usr/bin/env python')
        printer('"""CLI for %s, version %s."""',
                self.__package, self.__version)
        printer('# NOTE: This file is autogenerated and should not be edited '
                'by hand.')
        # TODO(craigcitro): Add a build stamp, along with some other
        # information.
        printer()
        printer('import code')
        printer('import os')
        printer('import platform')
        printer('import sys')
        printer()
        printer('from %s import message_types', self.__protorpc_package)
        printer('from %s import messages', self.__protorpc_package)
        printer()
        printer('from google.apputils import appcommands')
        printer('import gflags as flags')
        printer()
        printer('import %s as apitools_base', self.__base_files_package)
        printer('from %s import cli as apitools_base_cli',
                self.__base_files_package)
        printer('import %s as client_lib',
                self.__client_info.client_rule_name)
        printer('import %s as messages',
                self.__client_info.messages_rule_name)
        self.__PrintFlagDeclarations(printer)
        printer()
        printer()
        self.__PrintGetGlobalParams(printer)
        self.__PrintGetClient(printer)
        self.__PrintPyShell(printer)
        self.__PrintCommands(printer)
        printer('def main(_):')
        with printer.Indent():
            printer("appcommands.AddCmd('pyshell', PyShell)")
            for command_info in self.__command_list:
                printer("appcommands.AddCmd('%s', %s)",
                        command_info.name, command_info.class_name)
            printer()
            printer('apitools_base_cli.SetupLogger()')
            # TODO(craigcitro): Just call SetDefaultCommand as soon as
            # another appcommands release happens and this exists
            # externally.
            printer("if hasattr(appcommands, 'SetDefaultCommand'):")
            with printer.Indent():
                printer("appcommands.SetDefaultCommand('pyshell')")
        printer()
        printer()
        printer('run_main = apitools_base_cli.run_main')
        printer()
        printer("if __name__ == '__main__':")
        with printer.Indent():
            printer('appcommands.Run()')

    def __PrintCommands(self, printer):
        """Print all commands in this registry using printer.

        Each command becomes a NewCmd subclass whose __init__ declares the
        command's flags and whose RunWithArgs builds the request from the
        positional args and any explicitly given (non-special) flags, sets
        up the optional upload/download objects, calls the client method and
        prints the formatted result.
        """
        for command_info in self.__command_list:
            arg_list = [arg_info.name for arg_info in command_info.args]
            printer(
                'class %s(apitools_base_cli.NewCmd):', command_info.class_name)
            with printer.Indent():
                printer('"""Command wrapping %s."""',
                        command_info.client_method_path)
                printer()
                printer('usage = """%s%s%s"""',
                        command_info.name,
                        ' ' if arg_list else '',
                        ' '.join('<%s>' % argname for argname in arg_list))
                printer()
                printer('def __init__(self, name, fv):')
                with printer.Indent():
                    printer('super(%s, self).__init__(name, fv)',
                            command_info.class_name)
                    for flag in command_info.flags:
                        self.__PrintFlag(printer, flag)
                printer()
                printer('def RunWithArgs(%s):', ', '.join(['self'] + arg_list))
                with printer.Indent():
                    self.__PrintCommandDocstring(printer, command_info)
                    printer('client = GetClientFromFlags()')
                    printer('global_params = GetGlobalParamsFromFlags()')
                    printer(
                        'request = messages.%s(', command_info.request_type)
                    with printer.Indent(indent=' '):
                        for arg in command_info.args:
                            rhs = arg.name
                            if arg.conversion:
                                rhs = arg.conversion % arg.name
                            printer('%s=%s,', arg.name, rhs)
                        printer(')')
                    # Only flags the user actually passed are copied onto
                    # the request; special (upload/download) flags are not
                    # request attributes.
                    for flag_info in command_info.flags:
                        if flag_info.special:
                            continue
                        rhs = 'FLAGS.%s' % flag_info.name
                        if flag_info.conversion:
                            rhs = flag_info.conversion % rhs
                        printer('if FLAGS[%r].present:', flag_info.name)
                        with printer.Indent():
                            printer('request.%s = %s', flag_info.name, rhs)
                    call_args = ['request', 'global_params=global_params']
                    if command_info.has_upload:
                        call_args.append('upload=upload')
                        printer('upload = None')
                        printer('if FLAGS.upload_filename:')
                        with printer.Indent():
                            printer('upload = apitools_base.Upload.FromFile(')
                            printer(' FLAGS.upload_filename, '
                                    'FLAGS.upload_mime_type,')
                            printer(' progress_callback='
                                    'apitools_base.UploadProgressPrinter,')
                            printer(' finish_callback='
                                    'apitools_base.UploadCompletePrinter)')
                    if command_info.has_download:
                        call_args.append('download=download')
                        printer('download = None')
                        printer('if FLAGS.download_filename:')
                        with printer.Indent():
                            printer('download = apitools_base.Download.'
                                    'FromFile(FLAGS.download_filename, '
                                    'overwrite=FLAGS.overwrite,')
                            printer(' progress_callback='
                                    'apitools_base.DownloadProgressPrinter,')
                            printer(' finish_callback='
                                    'apitools_base.DownloadCompletePrinter)')
                    printer(
                        'result = client.%s(', command_info.client_method_path)
                    with printer.Indent(indent=' '):
                        printer('%s)', ', '.join(call_args))
                    printer('print apitools_base_cli.FormatOutput(result)')
            printer()
            printer()