#
# Copyright 2008 Google Inc. All Rights Reserved.
#
"""
This module contains the generic CLI object

High Level Design:

The atest class contains attributes & methods generic to all the CLI
operations.

The class inheritance is shown here using the command
'atest host create ...' as an example:

atest <-- host <-- host_create <-- site_host_create

Note: The site_<topic>.py and its classes are only needed if you need
to override the common <topic>.py methods with your site specific ones.


High Level Algorithm:

1. Topic selection
   atest figures out the topic and action from the first 2 arguments
   on the command line and imports the <topic> (or site_<topic>)
   module.

2. Init
   The main atest module creates a <topic>_<action> object. The
   __init__() function is used to set up the parser options, if this
   <action> has some specific options to add to its <topic>.

   If it exists, the child __init__() method must call its parent
   class __init__() before adding its own parser arguments.

3. Parsing
   If the child wants to validate the parsing (e.g. make sure that
   there are hosts in the arguments), or if it wants to check the
   options it added in its __init__(), it should implement a parse()
   method.

   The child parse() must call its parent's parse(), which returns the
   options dictionary and the rest of the command line arguments
   (leftover). Each level gets to see all the options, but the
   leftovers can be deleted as they can be consumed by only one
   object.

4. Execution
   The execute() method is specific to the child and should use
   self.execute_rpc() to send commands to the Autotest Front-End. It
   should return results.

5. Output
   The child output() method is called with the execute() results as a
   parameter. This is child-specific, but should leverage the
   atest.print_*() methods.
"""

import optparse
import os
import re
import sys
import textwrap
import traceback
import urllib2

from autotest_lib.cli import rpc
from autotest_lib.client.common_lib.test_utils import mock


# Maps the AFE keys to printable names.
# Note: 'invalid' is printed as 'Valid' because its converter in
# KEYS_CONVERT negates the flag.
KEYS_TO_NAMES_EN = {'hostname': 'Host',
                    'platform': 'Platform',
                    'status': 'Status',
                    'locked': 'Locked',
                    'locked_by': 'Locked by',
                    'lock_time': 'Locked time',
                    'lock_reason': 'Lock Reason',
                    'labels': 'Labels',
                    'description': 'Description',
                    'hosts': 'Hosts',
                    'users': 'Users',
                    'id': 'Id',
                    'name': 'Name',
                    'invalid': 'Valid',
                    'login': 'Login',
                    'access_level': 'Access Level',
                    'job_id': 'Job Id',
                    'job_owner': 'Job Owner',
                    'job_name': 'Job Name',
                    'test_type': 'Test Type',
                    'test_class': 'Test Class',
                    'path': 'Path',
                    'owner': 'Owner',
                    'status_counts': 'Status Counts',
                    'hosts_status': 'Host Status',
                    'hosts_selected_status': 'Hosts filtered by Status',
                    'priority': 'Priority',
                    'control_type': 'Control Type',
                    'created_on': 'Created On',
                    'control_file': 'Control File',
                    'only_if_needed': 'Use only if needed',
                    'protection': 'Protection',
                    'run_verify': 'Run verify',
                    'reboot_before': 'Pre-job reboot',
                    'reboot_after': 'Post-job reboot',
                    'experimental': 'Experimental',
                    'synch_count': 'Sync Count',
                    'max_number_of_machines': 'Max. hosts to use',
                    'parse_failed_repair': 'Include failed repair results',
                    'shard': 'Shard',
                    }

# Tag that replaces the failing item inside a failure message, so that
# identical errors raised for different items can be grouped together.
FAIL_TAG = '<XYZ>'

# Global socket timeout: uploading kernels can take much,
# much longer than the default
UPLOAD_SOCKET_TIMEOUT = 60*30


# Conversion functions to be called for printing,
# e.g. to print True/False for booleans.
123def __convert_platform(field): 124 if field is None: 125 return "" 126 elif isinstance(field, int): 127 # Can be 0/1 for False/True 128 return str(bool(field)) 129 else: 130 # Can be a platform name 131 return field 132 133 134def _int_2_bool_string(value): 135 return str(bool(value)) 136 137KEYS_CONVERT = {'locked': _int_2_bool_string, 138 'invalid': lambda flag: str(bool(not flag)), 139 'only_if_needed': _int_2_bool_string, 140 'platform': __convert_platform, 141 'labels': lambda labels: ', '.join(labels), 142 'shards': lambda shard: shard.hostname if shard else ''} 143 144 145def _get_item_key(item, key): 146 """Allow for lookups in nested dictionaries using '.'s within a key.""" 147 if key in item: 148 return item[key] 149 nested_item = item 150 for subkey in key.split('.'): 151 if not subkey: 152 raise ValueError('empty subkey in %r' % key) 153 try: 154 nested_item = nested_item[subkey] 155 except KeyError, e: 156 raise KeyError('%r - looking up key %r in %r' % 157 (e, key, nested_item)) 158 else: 159 return nested_item 160 161 162class CliError(Exception): 163 """Error raised by cli calls. 164 """ 165 pass 166 167 168class item_parse_info(object): 169 """Object keeping track of the parsing options. 
170 """ 171 172 def __init__(self, attribute_name, inline_option='', 173 filename_option='', use_leftover=False): 174 """Object keeping track of the parsing options that will 175 make up the content of the atest attribute: 176 attribute_name: the atest attribute name to populate (label) 177 inline_option: the option containing the items (--label) 178 filename_option: the option containing the filename (--blist) 179 use_leftover: whether to add the leftover arguments or not.""" 180 self.attribute_name = attribute_name 181 self.filename_option = filename_option 182 self.inline_option = inline_option 183 self.use_leftover = use_leftover 184 185 186 def get_values(self, options, leftover=[]): 187 """Returns the value for that attribute by accumualting all 188 the values found through the inline option, the parsing of the 189 file and the leftover""" 190 191 def __get_items(input, split_spaces=True): 192 """Splits a string of comma separated items. Escaped commas will not 193 be split. I.e. Splitting 'a, b\,c, d' will yield ['a', 'b,c', 'd']. 194 If split_spaces is set to False spaces will not be split. I.e. 195 Splitting 'a b, c\,d, e' will yield ['a b', 'c,d', 'e']""" 196 197 # Replace escaped slashes with null characters so we don't misparse 198 # proceeding commas. 199 input = input.replace(r'\\', '\0') 200 201 # Split on commas which are not preceded by a slash. 202 if not split_spaces: 203 split = re.split(r'(?<!\\),', input) 204 else: 205 split = re.split(r'(?<!\\),|\s', input) 206 207 # Convert null characters to single slashes and escaped commas to 208 # just plain commas. 
209 return (item.strip().replace('\0', '\\').replace(r'\,', ',') for 210 item in split if item.strip()) 211 212 if self.use_leftover: 213 add_on = leftover 214 leftover = [] 215 else: 216 add_on = [] 217 218 # Start with the add_on 219 result = set() 220 for items in add_on: 221 # Don't split on space here because the add-on 222 # may have some spaces (like the job name) 223 result.update(__get_items(items, split_spaces=False)) 224 225 # Process the inline_option, if any 226 try: 227 items = getattr(options, self.inline_option) 228 result.update(__get_items(items)) 229 except (AttributeError, TypeError): 230 pass 231 232 # Process the file list, if any and not empty 233 # The file can contain space and/or comma separated items 234 try: 235 flist = getattr(options, self.filename_option) 236 file_content = [] 237 for line in open(flist).readlines(): 238 file_content += __get_items(line) 239 if len(file_content) == 0: 240 raise CliError("Empty file %s" % flist) 241 result.update(file_content) 242 except (AttributeError, TypeError): 243 pass 244 except IOError: 245 raise CliError("Could not open file %s" % flist) 246 247 return list(result), leftover 248 249 250class atest(object): 251 """Common class for generic processing 252 Should only be instantiated by itself for usage 253 references, otherwise, the <topic> objects should 254 be used.""" 255 msg_topic = ('[acl|host|job|label|shard|test|user|server|' 256 'stable_version]') 257 usage_action = '[action]' 258 msg_items = '' 259 260 def invalid_arg(self, header, follow_up=''): 261 """Fail the command with error that command line has invalid argument. 262 263 @param header: Header of the error message. 264 @param follow_up: Extra error message, default to empty string. 
265 """ 266 twrap = textwrap.TextWrapper(initial_indent=' ', 267 subsequent_indent=' ') 268 rest = twrap.fill(follow_up) 269 270 if self.kill_on_failure: 271 self.invalid_syntax(header + rest) 272 else: 273 print >> sys.stderr, header + rest 274 275 276 def invalid_syntax(self, msg): 277 """Fail the command with error that the command line syntax is wrong. 278 279 @param msg: Error message. 280 """ 281 print 282 print >> sys.stderr, msg 283 print 284 print "usage:", 285 print self._get_usage() 286 print 287 sys.exit(1) 288 289 290 def generic_error(self, msg): 291 """Fail the command with a generic error. 292 293 @param msg: Error message. 294 """ 295 if self.debug: 296 traceback.print_exc() 297 print >> sys.stderr, msg 298 sys.exit(1) 299 300 301 def parse_json_exception(self, full_error): 302 """Parses the JSON exception to extract the bad 303 items and returns them 304 This is very kludgy for the moment, but we would need 305 to refactor the exceptions sent from the front end 306 to make this better. 307 308 @param full_error: The complete error message. 309 """ 310 errmsg = str(full_error).split('Traceback')[0].rstrip('\n') 311 parts = errmsg.split(':') 312 # Kludge: If there are 2 colons the last parts contains 313 # the items that failed. 314 if len(parts) != 3: 315 return [] 316 return [item.strip() for item in parts[2].split(',') if item.strip()] 317 318 319 def failure(self, full_error, item=None, what_failed='', fatal=False): 320 """If kill_on_failure, print this error and die, 321 otherwise, queue the error and accumulate all the items 322 that triggered the same error. 323 324 @param full_error: The complete error message. 325 @param item: Name of the actionable item, e.g., hostname. 326 @param what_failed: Name of the failed item. 327 @param fatal: True to exit the program with failure. 
328 """ 329 330 if self.debug: 331 errmsg = str(full_error) 332 else: 333 errmsg = str(full_error).split('Traceback')[0].rstrip('\n') 334 335 if self.kill_on_failure or fatal: 336 print >> sys.stderr, "%s\n %s" % (what_failed, errmsg) 337 sys.exit(1) 338 339 # Build a dictionary with the 'what_failed' as keys. The 340 # values are dictionaries with the errmsg as keys and a set 341 # of items as values. 342 # self.failed = 343 # {'Operation delete_host_failed': {'AclAccessViolation: 344 # set('host0', 'host1')}} 345 # Try to gather all the same error messages together, 346 # even if they contain the 'item' 347 if item and item in errmsg: 348 errmsg = errmsg.replace(item, FAIL_TAG) 349 if self.failed.has_key(what_failed): 350 self.failed[what_failed].setdefault(errmsg, set()).add(item) 351 else: 352 self.failed[what_failed] = {errmsg: set([item])} 353 354 355 def show_all_failures(self): 356 """Print all failure information. 357 """ 358 if not self.failed: 359 return 0 360 for what_failed in self.failed.keys(): 361 print >> sys.stderr, what_failed + ':' 362 for (errmsg, items) in self.failed[what_failed].iteritems(): 363 if len(items) == 0: 364 print >> sys.stderr, errmsg 365 elif items == set(['']): 366 print >> sys.stderr, ' ' + errmsg 367 elif len(items) == 1: 368 # Restore the only item 369 if FAIL_TAG in errmsg: 370 errmsg = errmsg.replace(FAIL_TAG, items.pop()) 371 else: 372 errmsg = '%s (%s)' % (errmsg, items.pop()) 373 print >> sys.stderr, ' ' + errmsg 374 else: 375 print >> sys.stderr, ' ' + errmsg + ' with <XYZ> in:' 376 twrap = textwrap.TextWrapper(initial_indent=' ', 377 subsequent_indent=' ') 378 items = list(items) 379 items.sort() 380 print >> sys.stderr, twrap.fill(', '.join(items)) 381 return 1 382 383 384 def __init__(self): 385 """Setup the parser common options""" 386 # Initialized for unit tests. 
387 self.afe = None 388 self.failed = {} 389 self.data = {} 390 self.debug = False 391 self.parse_delim = '|' 392 self.kill_on_failure = False 393 self.web_server = '' 394 self.verbose = False 395 self.no_confirmation = False 396 self.topic_parse_info = item_parse_info(attribute_name='not_used') 397 398 self.parser = optparse.OptionParser(self._get_usage()) 399 self.parser.add_option('-g', '--debug', 400 help='Print debugging information', 401 action='store_true', default=False) 402 self.parser.add_option('--kill-on-failure', 403 help='Stop at the first failure', 404 action='store_true', default=False) 405 self.parser.add_option('--parse', 406 help='Print the output using | ' 407 'separated key=value fields', 408 action='store_true', default=False) 409 self.parser.add_option('--parse-delim', 410 help='Delimiter to use to separate the ' 411 'key=value fields', default='|') 412 self.parser.add_option('--no-confirmation', 413 help=('Skip all confirmation in when function ' 414 'require_confirmation is called.'), 415 action='store_true', default=False) 416 self.parser.add_option('-v', '--verbose', 417 action='store_true', default=False) 418 self.parser.add_option('-w', '--web', 419 help='Specify the autotest server ' 420 'to talk to', 421 action='store', type='string', 422 dest='web_server', default=None) 423 424 425 def _get_usage(self): 426 return "atest %s %s [options] %s" % (self.msg_topic.lower(), 427 self.usage_action, 428 self.msg_items) 429 430 431 def backward_compatibility(self, action, argv): 432 """To be overidden by subclass if their syntax changed. 433 434 @param action: Name of the action. 435 @param argv: A list of arguments. 436 """ 437 return action 438 439 440 def parse(self, parse_info=[], req_items=None): 441 """Parse command arguments. 442 443 parse_info is a list of item_parse_info objects. 444 There should only be one use_leftover set to True in the list. 445 446 Also check that the req_items is not empty after parsing. 
447 448 @param parse_info: A list of item_parse_info objects. 449 @param req_items: A list of required items. 450 """ 451 (options, leftover) = self.parse_global() 452 453 all_parse_info = parse_info[:] 454 all_parse_info.append(self.topic_parse_info) 455 456 try: 457 for item_parse_info in all_parse_info: 458 values, leftover = item_parse_info.get_values(options, 459 leftover) 460 setattr(self, item_parse_info.attribute_name, values) 461 except CliError, s: 462 self.invalid_syntax(s) 463 464 if (req_items and not getattr(self, req_items, None)): 465 self.invalid_syntax('%s %s requires at least one %s' % 466 (self.msg_topic, 467 self.usage_action, 468 self.msg_topic)) 469 470 return (options, leftover) 471 472 473 def parse_global(self): 474 """Parse the global arguments. 475 476 It consumes what the common object needs to know, and 477 let the children look at all the options. We could 478 remove the options that we have used, but there is no 479 harm in leaving them, and the children may need them 480 in the future. 481 482 Must be called from its children parse()""" 483 (options, leftover) = self.parser.parse_args() 484 # Handle our own options setup in __init__() 485 self.debug = options.debug 486 self.kill_on_failure = options.kill_on_failure 487 488 if options.parse: 489 suffix = '_parse' 490 else: 491 suffix = '_std' 492 for func in ['print_fields', 'print_table', 493 'print_by_ids', 'print_list']: 494 setattr(self, func, getattr(self, func + suffix)) 495 496 self.parse_delim = options.parse_delim 497 498 self.verbose = options.verbose 499 self.no_confirmation = options.no_confirmation 500 self.web_server = options.web_server 501 try: 502 self.afe = rpc.afe_comm(self.web_server) 503 except rpc.AuthError, s: 504 self.failure(str(s), fatal=True) 505 506 return (options, leftover) 507 508 509 def check_and_create_items(self, op_get, op_create, 510 items, **data_create): 511 """Create the items if they don't exist already. 
512 513 @param op_get: Name of `get` RPC. 514 @param op_create: Name of `create` RPC. 515 @param items: Actionable items specified in CLI command, e.g., hostname, 516 to be passed to each RPC. 517 @param data_create: Data to be passed to `create` RPC. 518 """ 519 for item in items: 520 ret = self.execute_rpc(op_get, name=item) 521 522 if len(ret) == 0: 523 try: 524 data_create['name'] = item 525 self.execute_rpc(op_create, **data_create) 526 except CliError: 527 continue 528 529 530 def execute_rpc(self, op, item='', **data): 531 """Execute RPC. 532 533 @param op: Name of the RPC. 534 @param item: Actionable item specified in CLI command. 535 @param data: Data to be passed to RPC. 536 """ 537 retry = 2 538 while retry: 539 try: 540 return self.afe.run(op, **data) 541 except urllib2.URLError, err: 542 if hasattr(err, 'reason'): 543 if 'timed out' not in err.reason: 544 self.invalid_syntax('Invalid server name %s: %s' % 545 (self.afe.web_server, err)) 546 if hasattr(err, 'code'): 547 error_parts = [str(err)] 548 if self.debug: 549 error_parts.append(err.read()) # read the response body 550 self.failure('\n\n'.join(error_parts), item=item, 551 what_failed=("Error received from web server")) 552 raise CliError("Error from web server") 553 if self.debug: 554 print 'retrying: %r %d' % (data, retry) 555 retry -= 1 556 if retry == 0: 557 if item: 558 myerr = '%s timed out for %s' % (op, item) 559 else: 560 myerr = '%s timed out' % op 561 self.failure(myerr, item=item, 562 what_failed=("Timed-out contacting " 563 "the Autotest server")) 564 raise CliError("Timed-out contacting the Autotest server") 565 except mock.CheckPlaybackError: 566 raise 567 except Exception, full_error: 568 # There are various exceptions throwns by JSON, 569 # urllib & httplib, so catch them all. 570 self.failure(full_error, item=item, 571 what_failed='Operation %s failed' % op) 572 raise CliError(str(full_error)) 573 574 575 # There is no output() method in the atest object (yet?) 
576 # but here are some helper functions to be used by its 577 # children 578 def print_wrapped(self, msg, values): 579 """Print given message and values in wrapped lines unless 580 AUTOTEST_CLI_NO_WRAP is specified in environment variables. 581 582 @param msg: Message to print. 583 @param values: A list of values to print. 584 """ 585 if len(values) == 0: 586 return 587 elif len(values) == 1: 588 print msg + ': ' 589 elif len(values) > 1: 590 if msg.endswith('s'): 591 print msg + ': ' 592 else: 593 print msg + 's: ' 594 595 values.sort() 596 597 if 'AUTOTEST_CLI_NO_WRAP' in os.environ: 598 print '\n'.join(values) 599 return 600 601 twrap = textwrap.TextWrapper(initial_indent='\t', 602 subsequent_indent='\t') 603 print twrap.fill(', '.join(values)) 604 605 606 def __conv_value(self, type, value): 607 return KEYS_CONVERT.get(type, str)(value) 608 609 610 def print_fields_std(self, items, keys, title=None): 611 """Print the keys in each item, one on each line. 612 613 @param items: Items to print. 614 @param keys: Name of the keys to look up each item in items. 615 @param title: Title of the output, default to None. 616 """ 617 if not items: 618 return 619 if title: 620 print title 621 for item in items: 622 for key in keys: 623 print '%s: %s' % (KEYS_TO_NAMES_EN[key], 624 self.__conv_value(key, 625 _get_item_key(item, key))) 626 627 628 def print_fields_parse(self, items, keys, title=None): 629 """Print the keys in each item as comma separated name=value 630 631 @param items: Items to print. 632 @param keys: Name of the keys to look up each item in items. 633 @param title: Title of the output, default to None. 
634 """ 635 for item in items: 636 values = ['%s=%s' % (KEYS_TO_NAMES_EN[key], 637 self.__conv_value(key, 638 _get_item_key(item, key))) 639 for key in keys 640 if self.__conv_value(key, 641 _get_item_key(item, key)) != ''] 642 print self.parse_delim.join(values) 643 644 645 def __find_justified_fmt(self, items, keys): 646 """Find the max length for each field. 647 648 @param items: Items to lookup for. 649 @param keys: Name of the keys to look up each item in items. 650 """ 651 lens = {} 652 # Don't justify the last field, otherwise we have blank 653 # lines when the max is overlaps but the current values 654 # are smaller 655 if not items: 656 print "No results" 657 return 658 for key in keys[:-1]: 659 lens[key] = max(len(self.__conv_value(key, 660 _get_item_key(item, key))) 661 for item in items) 662 lens[key] = max(lens[key], len(KEYS_TO_NAMES_EN[key])) 663 lens[keys[-1]] = 0 664 665 return ' '.join(["%%-%ds" % lens[key] for key in keys]) 666 667 668 def print_dict(self, items, title=None, line_before=False): 669 """Print a dictionary. 670 671 @param items: Dictionary to print. 672 @param title: Title of the output, default to None. 673 @param line_before: True to print an empty line before the output, 674 default to False. 675 """ 676 if not items: 677 return 678 if line_before: 679 print 680 print title 681 for key, value in items.items(): 682 print '%s : %s' % (key, value) 683 684 685 def print_table_std(self, items, keys_header, sublist_keys=()): 686 """Print a mix of header and lists in a user readable format. 687 688 The headers are justified, the sublist_keys are wrapped. 689 690 @param items: Items to print. 691 @param keys_header: Header of the keys, use to look up in items. 692 @param sublist_keys: Keys for sublist in each item. 
693 """ 694 if not items: 695 return 696 fmt = self.__find_justified_fmt(items, keys_header) 697 header = tuple(KEYS_TO_NAMES_EN[key] for key in keys_header) 698 print fmt % header 699 for item in items: 700 values = tuple(self.__conv_value(key, 701 _get_item_key(item, key)) 702 for key in keys_header) 703 print fmt % values 704 if sublist_keys: 705 for key in sublist_keys: 706 self.print_wrapped(KEYS_TO_NAMES_EN[key], 707 _get_item_key(item, key)) 708 print '\n' 709 710 711 def print_table_parse(self, items, keys_header, sublist_keys=()): 712 """Print a mix of header and lists in a user readable format. 713 714 @param items: Items to print. 715 @param keys_header: Header of the keys, use to look up in items. 716 @param sublist_keys: Keys for sublist in each item. 717 """ 718 for item in items: 719 values = ['%s=%s' % (KEYS_TO_NAMES_EN[key], 720 self.__conv_value(key, _get_item_key(item, key))) 721 for key in keys_header 722 if self.__conv_value(key, 723 _get_item_key(item, key)) != ''] 724 725 if sublist_keys: 726 [values.append('%s=%s'% (KEYS_TO_NAMES_EN[key], 727 ','.join(_get_item_key(item, key)))) 728 for key in sublist_keys 729 if len(_get_item_key(item, key))] 730 731 print self.parse_delim.join(values) 732 733 734 def print_by_ids_std(self, items, title=None, line_before=False): 735 """Prints ID & names of items in a user readable form. 736 737 @param items: Items to print. 738 @param title: Title of the output, default to None. 739 @param line_before: True to print an empty line before the output, 740 default to False. 741 """ 742 if not items: 743 return 744 if line_before: 745 print 746 if title: 747 print title + ':' 748 self.print_table_std(items, keys_header=['id', 'name']) 749 750 751 def print_by_ids_parse(self, items, title=None, line_before=False): 752 """Prints ID & names of items in a parseable format. 753 754 @param items: Items to print. 755 @param title: Title of the output, default to None. 
756 @param line_before: True to print an empty line before the output, 757 default to False. 758 """ 759 if not items: 760 return 761 if line_before: 762 print 763 if title: 764 print title + '=', 765 values = [] 766 for item in items: 767 values += ['%s=%s' % (KEYS_TO_NAMES_EN[key], 768 self.__conv_value(key, 769 _get_item_key(item, key))) 770 for key in ['id', 'name'] 771 if self.__conv_value(key, 772 _get_item_key(item, key)) != ''] 773 print self.parse_delim.join(values) 774 775 776 def print_list_std(self, items, key): 777 """Print a wrapped list of results 778 779 @param items: Items to to lookup for given key, could be a nested 780 dictionary. 781 @param key: Name of the key to look up for value. 782 """ 783 if not items: 784 return 785 print ' '.join(_get_item_key(item, key) for item in items) 786 787 788 def print_list_parse(self, items, key): 789 """Print a wrapped list of results. 790 791 @param items: Items to to lookup for given key, could be a nested 792 dictionary. 793 @param key: Name of the key to look up for value. 794 """ 795 if not items: 796 return 797 print '%s=%s' % (KEYS_TO_NAMES_EN[key], 798 ','.join(_get_item_key(item, key) for item in items)) 799 800 801 @staticmethod 802 def prompt_confirmation(message=None): 803 """Prompt a question for user to confirm the action before proceeding. 804 805 @param message: A detailed message to explain possible impact of the 806 action. 807 808 @return: True to proceed or False to abort. 809 """ 810 if message: 811 print message 812 sys.stdout.write('Continue? [y/N] ') 813 read = raw_input().lower() 814 if read == 'y': 815 return True 816 else: 817 print 'User did not confirm. Aborting...' 818 return False 819 820 821 @staticmethod 822 def require_confirmation(message=None): 823 """Decorator to prompt a question for user to confirm action before 824 proceeding. 825 826 If user chooses not to proceed, do not call the function. 
827 828 @param message: A detailed message to explain possible impact of the 829 action. 830 831 @return: A decorator wrapper for calling the actual function. 832 """ 833 def deco_require_confirmation(func): 834 """Wrapper for the decorator. 835 836 @param func: Function to be called. 837 838 @return: the actual decorator to call the function. 839 """ 840 def func_require_confirmation(*args, **kwargs): 841 """Decorator to prompt a question for user to confirm. 842 843 @param message: A detailed message to explain possible impact of 844 the action. 845 """ 846 if (args[0].no_confirmation or 847 atest.prompt_confirmation(message)): 848 func(*args, **kwargs) 849 850 return func_require_confirmation 851 return deco_require_confirmation 852