1# pylint: disable=missing-docstring 2""" 3Utility functions for rpc_interface.py. We keep them in a separate file so that 4only RPC interface functions go into that file. 5""" 6 7__author__ = 'showard@google.com (Steve Howard)' 8 9import collections 10import datetime 11from functools import wraps 12import inspect 13import os 14import sys 15import django.db.utils 16import django.http 17 18from autotest_lib.frontend import thread_local 19from autotest_lib.frontend.afe import models, model_logic 20from autotest_lib.client.common_lib import control_data, error 21from autotest_lib.client.common_lib import global_config 22from autotest_lib.client.common_lib import time_utils 23from autotest_lib.client.common_lib.cros import dev_server 24from autotest_lib.server import utils as server_utils 25from autotest_lib.server.cros import provision 26from autotest_lib.server.cros.dynamic_suite import frontend_wrappers 27 28NULL_DATETIME = datetime.datetime.max 29NULL_DATE = datetime.date.max 30DUPLICATE_KEY_MSG = 'Duplicate entry' 31 32def prepare_for_serialization(objects): 33 """ 34 Prepare Python objects to be returned via RPC. 35 @param objects: objects to be prepared. 36 """ 37 if (isinstance(objects, list) and len(objects) and 38 isinstance(objects[0], dict) and 'id' in objects[0]): 39 objects = _gather_unique_dicts(objects) 40 return _prepare_data(objects) 41 42 43def prepare_rows_as_nested_dicts(query, nested_dict_column_names): 44 """ 45 Prepare a Django query to be returned via RPC as a sequence of nested 46 dictionaries. 47 48 @param query - A Django model query object with a select_related() method. 49 @param nested_dict_column_names - A list of column/attribute names for the 50 rows returned by query to expand into nested dictionaries using 51 their get_object_dict() method when not None. 52 53 @returns An list suitable to returned in an RPC. 54 """ 55 all_dicts = [] 56 for row in query.select_related(): 57 row_dict = row.get_object_dict() 58 for column in nested_dict_column_names: 59 if row_dict[column] is not None: 60 row_dict[column] = getattr(row, column).get_object_dict() 61 all_dicts.append(row_dict) 62 return prepare_for_serialization(all_dicts) 63 64 65def _prepare_data(data): 66 """ 67 Recursively process data structures, performing necessary type 68 conversions to values in data to allow for RPC serialization: 69 -convert datetimes to strings 70 -convert tuples and sets to lists 71 """ 72 if isinstance(data, dict): 73 new_data = {} 74 for key, value in data.iteritems(): 75 new_data[key] = _prepare_data(value) 76 return new_data 77 elif (isinstance(data, list) or isinstance(data, tuple) or 78 isinstance(data, set)): 79 return [_prepare_data(item) for item in data] 80 elif isinstance(data, datetime.date): 81 if data is NULL_DATETIME or data is NULL_DATE: 82 return None 83 return str(data) 84 else: 85 return data 86 87 88def fetchall_as_list_of_dicts(cursor): 89 """ 90 Converts each row in the cursor to a dictionary so that values can be read 91 by using the column name. 92 @param cursor: The database cursor to read from. 93 @returns: A list of each row in the cursor as a dictionary. 94 """ 95 desc = cursor.description 96 return [ dict(zip([col[0] for col in desc], row)) 97 for row in cursor.fetchall() ] 98 99 100def raw_http_response(response_data, content_type=None): 101 response = django.http.HttpResponse(response_data, mimetype=content_type) 102 response['Content-length'] = str(len(response.content)) 103 return response 104 105 106def _gather_unique_dicts(dict_iterable): 107 """\ 108 Pick out unique objects (by ID) from an iterable of object dicts. 109 """ 110 objects = collections.OrderedDict() 111 for obj in dict_iterable: 112 objects.setdefault(obj['id'], obj) 113 return objects.values() 114 115 116def extra_job_status_filters(not_yet_run=False, running=False, finished=False): 117 """\ 118 Generate a SQL WHERE clause for job status filtering, and return it in 119 a dict of keyword args to pass to query.extra(). 120 * not_yet_run: all HQEs are Queued 121 * finished: all HQEs are complete 122 * running: everything else 123 """ 124 if not (not_yet_run or running or finished): 125 return {} 126 not_queued = ('(SELECT job_id FROM afe_host_queue_entries ' 127 'WHERE status != "%s")' 128 % models.HostQueueEntry.Status.QUEUED) 129 not_finished = ('(SELECT job_id FROM afe_host_queue_entries ' 130 'WHERE not complete)') 131 132 where = [] 133 if not_yet_run: 134 where.append('id NOT IN ' + not_queued) 135 if running: 136 where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished)) 137 if finished: 138 where.append('id NOT IN ' + not_finished) 139 return {'where': [' OR '.join(['(%s)' % x for x in where])]} 140 141 142def extra_job_type_filters(extra_args, suite=False, 143 sub=False, standalone=False): 144 """\ 145 Generate a SQL WHERE clause for job status filtering, and return it in 146 a dict of keyword args to pass to query.extra(). 147 148 param extra_args: a dict of existing extra_args. 149 150 No more than one of the parameters should be passed as True: 151 * suite: job which is parent of other jobs 152 * sub: job with a parent job 153 * standalone: job with no child or parent jobs 154 """ 155 assert not ((suite and sub) or 156 (suite and standalone) or 157 (sub and standalone)), ('Cannot specify more than one ' 158 'filter to this function') 159 160 where = extra_args.get('where', []) 161 parent_job_id = ('DISTINCT parent_job_id') 162 child_job_id = ('id') 163 filter_common = ('(SELECT %s FROM afe_jobs ' 164 'WHERE parent_job_id IS NOT NULL)') 165 166 if suite: 167 where.append('id IN ' + filter_common % parent_job_id) 168 elif sub: 169 where.append('id IN ' + filter_common % child_job_id) 170 elif standalone: 171 where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query ' 172 'WHERE parent_job_id IS NOT NULL' 173 ' AND (sub_query.parent_job_id=afe_jobs.id' 174 ' OR sub_query.id=afe_jobs.id))') 175 else: 176 return extra_args 177 178 extra_args['where'] = where 179 return extra_args 180 181 182 183def extra_host_filters(multiple_labels=()): 184 """\ 185 Generate SQL WHERE clauses for matching hosts in an intersection of 186 labels. 187 """ 188 extra_args = {} 189 where_str = ('afe_hosts.id in (select host_id from afe_hosts_labels ' 190 'where label_id=%s)') 191 extra_args['where'] = [where_str] * len(multiple_labels) 192 extra_args['params'] = [models.Label.smart_get(label).id 193 for label in multiple_labels] 194 return extra_args 195 196 197def get_host_query(multiple_labels, exclude_only_if_needed_labels, 198 valid_only, filter_data): 199 if valid_only: 200 query = models.Host.valid_objects.all() 201 else: 202 query = models.Host.objects.all() 203 204 if exclude_only_if_needed_labels: 205 only_if_needed_labels = models.Label.valid_objects.filter( 206 only_if_needed=True) 207 if only_if_needed_labels.count() > 0: 208 only_if_needed_ids = ','.join( 209 str(label['id']) 210 for label in only_if_needed_labels.values('id')) 211 query = models.Host.objects.add_join( 212 query, 'afe_hosts_labels', join_key='host_id', 213 join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)' 214 % only_if_needed_ids), 215 suffix='_exclude_OIN', exclude=True) 216 try: 217 assert 'extra_args' not in filter_data 218 filter_data['extra_args'] = extra_host_filters(multiple_labels) 219 return models.Host.query_objects(filter_data, initial_query=query) 220 except models.Label.DoesNotExist: 221 return models.Host.objects.none() 222 223 224class InconsistencyException(Exception): 225 'Raised when a list of objects does not have a consistent value' 226 227 228def get_consistent_value(objects, field): 229 if not objects: 230 # well a list of nothing is consistent 231 return None 232 233 value = getattr(objects[0], field) 234 for obj in objects: 235 this_value = getattr(obj, field) 236 if this_value != value: 237 raise InconsistencyException(objects[0], obj) 238 return value 239 240 241def afe_test_dict_to_test_object(test_dict): 242 if not isinstance(test_dict, dict): 243 return test_dict 244 245 numerized_dict = {} 246 for key, value in test_dict.iteritems(): 247 try: 248 numerized_dict[key] = int(value) 249 except (ValueError, TypeError): 250 numerized_dict[key] = value 251 252 return type('TestObject', (object,), numerized_dict) 253 254 255def _check_is_server_test(test_type): 256 """Checks if the test type is a server test. 257 258 @param test_type The test type in enum integer or string. 259 260 @returns A boolean to identify if the test type is server test. 261 """ 262 if test_type is not None: 263 if isinstance(test_type, basestring): 264 try: 265 test_type = control_data.CONTROL_TYPE.get_value(test_type) 266 except AttributeError: 267 return False 268 return (test_type == control_data.CONTROL_TYPE.SERVER) 269 return False 270 271 272def prepare_generate_control_file(tests, profilers, db_tests=True): 273 if db_tests: 274 test_objects = [models.Test.smart_get(test) for test in tests] 275 else: 276 test_objects = [afe_test_dict_to_test_object(test) for test in tests] 277 278 profiler_objects = [models.Profiler.smart_get(profiler) 279 for profiler in profilers] 280 # ensure tests are all the same type 281 try: 282 test_type = get_consistent_value(test_objects, 'test_type') 283 except InconsistencyException, exc: 284 test1, test2 = exc.args 285 raise model_logic.ValidationError( 286 {'tests' : 'You cannot run both test_suites and server-side ' 287 'tests together (tests %s and %s differ' % ( 288 test1.name, test2.name)}) 289 290 is_server = _check_is_server_test(test_type) 291 if test_objects: 292 synch_count = max(test.sync_count for test in test_objects) 293 else: 294 synch_count = 1 295 296 if db_tests: 297 dependencies = set(label.name for label 298 in models.Label.objects.filter(test__in=test_objects)) 299 else: 300 dependencies = reduce( 301 set.union, [set(test.dependencies) for test in test_objects]) 302 303 cf_info = dict(is_server=is_server, synch_count=synch_count, 304 dependencies=list(dependencies)) 305 return cf_info, test_objects, profiler_objects 306 307 308def check_job_dependencies(host_objects, job_dependencies): 309 """ 310 Check that a set of machines satisfies a job's dependencies. 311 host_objects: list of models.Host objects 312 job_dependencies: list of names of labels 313 """ 314 # check that hosts satisfy dependencies 315 host_ids = [host.id for host in host_objects] 316 hosts_in_job = models.Host.objects.filter(id__in=host_ids) 317 ok_hosts = hosts_in_job 318 for index, dependency in enumerate(job_dependencies): 319 if not provision.is_for_special_action(dependency): 320 ok_hosts = ok_hosts.filter(labels__name=dependency) 321 failing_hosts = (set(host.hostname for host in host_objects) - 322 set(host.hostname for host in ok_hosts)) 323 if failing_hosts: 324 raise model_logic.ValidationError( 325 {'hosts' : 'Host(s) failed to meet job dependencies (' + 326 (', '.join(job_dependencies)) + '): ' + 327 (', '.join(failing_hosts))}) 328 329 330def check_job_metahost_dependencies(metahost_objects, job_dependencies): 331 """ 332 Check that at least one machine within the metahost spec satisfies the job's 333 dependencies. 334 335 @param metahost_objects A list of label objects representing the metahosts. 336 @param job_dependencies A list of strings of the required label names. 337 @raises NoEligibleHostException If a metahost cannot run the job. 338 """ 339 for metahost in metahost_objects: 340 hosts = models.Host.objects.filter(labels=metahost) 341 for label_name in job_dependencies: 342 if not provision.is_for_special_action(label_name): 343 hosts = hosts.filter(labels__name=label_name) 344 if not any(hosts): 345 raise error.NoEligibleHostException("No hosts within %s satisfy %s." 346 % (metahost.name, ', '.join(job_dependencies))) 347 348 349def _execution_key_for(host_queue_entry): 350 return (host_queue_entry.job.id, host_queue_entry.execution_subdir) 351 352 353def check_abort_synchronous_jobs(host_queue_entries): 354 # ensure user isn't aborting part of a synchronous autoserv execution 355 count_per_execution = {} 356 for queue_entry in host_queue_entries: 357 key = _execution_key_for(queue_entry) 358 count_per_execution.setdefault(key, 0) 359 count_per_execution[key] += 1 360 361 for queue_entry in host_queue_entries: 362 if not queue_entry.execution_subdir: 363 continue 364 execution_count = count_per_execution[_execution_key_for(queue_entry)] 365 if execution_count < queue_entry.job.synch_count: 366 raise model_logic.ValidationError( 367 {'' : 'You cannot abort part of a synchronous job execution ' 368 '(%d/%s), %d included, %d expected' 369 % (queue_entry.job.id, queue_entry.execution_subdir, 370 execution_count, queue_entry.job.synch_count)}) 371 372 373def check_modify_host(update_data): 374 """ 375 Sanity check modify_host* requests. 376 377 @param update_data: A dictionary with the changes to make to a host 378 or hosts. 379 """ 380 # Only the scheduler (monitor_db) is allowed to modify Host status. 381 # Otherwise race conditions happen as a hosts state is changed out from 382 # beneath tasks being run on a host. 383 if 'status' in update_data: 384 raise model_logic.ValidationError({ 385 'status': 'Host status can not be modified by the frontend.'}) 386 387 388def check_modify_host_locking(host, update_data): 389 """ 390 Checks when locking/unlocking has been requested if the host is already 391 locked/unlocked. 392 393 @param host: models.Host object to be modified 394 @param update_data: A dictionary with the changes to make to the host. 395 """ 396 locked = update_data.get('locked', None) 397 lock_reason = update_data.get('lock_reason', None) 398 if locked is not None: 399 if locked and host.locked: 400 raise model_logic.ValidationError({ 401 'locked': 'Host %s already locked by %s on %s.' % 402 (host.hostname, host.locked_by, host.lock_time)}) 403 if not locked and not host.locked: 404 raise model_logic.ValidationError({ 405 'locked': 'Host %s already unlocked.' % host.hostname}) 406 if locked and not lock_reason and not host.locked: 407 raise model_logic.ValidationError({ 408 'locked': 'Please provide a reason for locking Host %s' % 409 host.hostname}) 410 411 412def get_motd(): 413 dirname = os.path.dirname(__file__) 414 filename = os.path.join(dirname, "..", "..", "motd.txt") 415 text = '' 416 try: 417 fp = open(filename, "r") 418 try: 419 text = fp.read() 420 finally: 421 fp.close() 422 except: 423 pass 424 425 return text 426 427 428def _get_metahost_counts(metahost_objects): 429 metahost_counts = {} 430 for metahost in metahost_objects: 431 metahost_counts.setdefault(metahost, 0) 432 metahost_counts[metahost] += 1 433 return metahost_counts 434 435 436def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None): 437 hosts = [] 438 one_time_hosts = [] 439 meta_hosts = [] 440 hostless = False 441 442 queue_entries = job.hostqueueentry_set.all() 443 if queue_entry_filter_data: 444 queue_entries = models.HostQueueEntry.query_objects( 445 queue_entry_filter_data, initial_query=queue_entries) 446 447 for queue_entry in queue_entries: 448 if (queue_entry.host and (preserve_metahosts or 449 not queue_entry.meta_host)): 450 if queue_entry.deleted: 451 continue 452 if queue_entry.host.invalid: 453 one_time_hosts.append(queue_entry.host) 454 else: 455 hosts.append(queue_entry.host) 456 elif queue_entry.meta_host: 457 meta_hosts.append(queue_entry.meta_host) 458 else: 459 hostless = True 460 461 meta_host_counts = _get_metahost_counts(meta_hosts) 462 463 info = dict(dependencies=[label.name for label 464 in job.dependency_labels.all()], 465 hosts=hosts, 466 meta_hosts=meta_hosts, 467 meta_host_counts=meta_host_counts, 468 one_time_hosts=one_time_hosts, 469 hostless=hostless) 470 return info 471 472 473def check_for_duplicate_hosts(host_objects): 474 host_counts = collections.Counter(host_objects) 475 duplicate_hostnames = {host.hostname 476 for host, count in host_counts.iteritems() 477 if count > 1} 478 if duplicate_hostnames: 479 raise model_logic.ValidationError( 480 {'hosts' : 'Duplicate hosts: %s' 481 % ', '.join(duplicate_hostnames)}) 482 483 484def create_new_job(owner, options, host_objects, metahost_objects): 485 all_host_objects = host_objects + metahost_objects 486 dependencies = options.get('dependencies', []) 487 synch_count = options.get('synch_count') 488 489 if synch_count is not None and synch_count > len(all_host_objects): 490 raise model_logic.ValidationError( 491 {'hosts': 492 'only %d hosts provided for job with synch_count = %d' % 493 (len(all_host_objects), synch_count)}) 494 495 check_for_duplicate_hosts(host_objects) 496 497 for label_name in dependencies: 498 if provision.is_for_special_action(label_name): 499 # TODO: We could save a few queries 500 # if we had a bulk ensure-label-exists function, which used 501 # a bulk .get() call. The win is probably very small. 502 _ensure_label_exists(label_name) 503 504 # This only checks targeted hosts, not hosts eligible due to the metahost 505 check_job_dependencies(host_objects, dependencies) 506 check_job_metahost_dependencies(metahost_objects, dependencies) 507 508 options['dependencies'] = list( 509 models.Label.objects.filter(name__in=dependencies)) 510 511 job = models.Job.create(owner=owner, options=options, 512 hosts=all_host_objects) 513 job.queue(all_host_objects, 514 is_template=options.get('is_template', False)) 515 return job.id 516 517 518def _ensure_label_exists(name): 519 """ 520 Ensure that a label called |name| exists in the Django models. 521 522 This function is to be called from within afe rpcs only, as an 523 alternative to server.cros.provision.ensure_label_exists(...). It works 524 by Django model manipulation, rather than by making another create_label 525 rpc call. 526 527 @param name: the label to check for/create. 528 @raises ValidationError: There was an error in the response that was 529 not because the label already existed. 530 @returns True is a label was created, False otherwise. 531 """ 532 # Make sure this function is not called on shards but only on master. 533 assert not server_utils.is_shard() 534 try: 535 models.Label.objects.get(name=name) 536 except models.Label.DoesNotExist: 537 try: 538 new_label = models.Label.objects.create(name=name) 539 new_label.save() 540 return True 541 except django.db.utils.IntegrityError as e: 542 # It is possible that another suite/test already 543 # created the label between the check and save. 544 if DUPLICATE_KEY_MSG in str(e): 545 return False 546 else: 547 raise 548 return False 549 550 551def find_platform(host): 552 """ 553 Figure out the platform name for the given host 554 object. If none, the return value for either will be None. 555 556 @returns platform name for the given host. 557 """ 558 platforms = [label.name for label in host.label_list if label.platform] 559 if not platforms: 560 platform = None 561 else: 562 platform = platforms[0] 563 if len(platforms) > 1: 564 raise ValueError('Host %s has more than one platform: %s' % 565 (host.hostname, ', '.join(platforms))) 566 return platform 567 568 569# support for get_host_queue_entries_and_special_tasks() 570 571def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on): 572 return dict(type=type, 573 host=entry['host'], 574 job=job_dict, 575 execution_path=exec_path, 576 status=status, 577 started_on=started_on, 578 id=str(entry['id']) + type, 579 oid=entry['id']) 580 581 582def _special_task_to_dict(task, queue_entries): 583 """Transforms a special task dictionary to another form of dictionary. 584 585 @param task Special task as a dictionary type 586 @param queue_entries Host queue entries as a list of dictionaries. 587 588 @return Transformed dictionary for a special task. 589 """ 590 job_dict = None 591 if task['queue_entry']: 592 # Scan queue_entries to get the job detail info. 593 for qentry in queue_entries: 594 if task['queue_entry']['id'] == qentry['id']: 595 job_dict = qentry['job'] 596 break 597 # If not found, get it from DB. 598 if job_dict is None: 599 job = models.Job.objects.get(id=task['queue_entry']['job']) 600 job_dict = job.get_object_dict() 601 602 exec_path = server_utils.get_special_task_exec_path( 603 task['host']['hostname'], task['id'], task['task'], 604 time_utils.time_string_to_datetime(task['time_requested'])) 605 status = server_utils.get_special_task_status( 606 task['is_complete'], task['success'], task['is_active']) 607 return _common_entry_to_dict(task, task['task'], job_dict, 608 exec_path, status, task['time_started']) 609 610 611def _queue_entry_to_dict(queue_entry): 612 job_dict = queue_entry['job'] 613 tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner']) 614 exec_path = server_utils.get_hqe_exec_path(tag, 615 queue_entry['execution_subdir']) 616 return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path, 617 queue_entry['status'], queue_entry['started_on']) 618 619 620def prepare_host_queue_entries_and_special_tasks(interleaved_entries, 621 queue_entries): 622 """ 623 Prepare for serialization the interleaved entries of host queue entries 624 and special tasks. 625 Each element in the entries is a dictionary type. 626 The special task dictionary has only a job id for a job and lacks 627 the detail of the job while the host queue entry dictionary has. 628 queue_entries is used to look up the job detail info. 629 630 @param interleaved_entries Host queue entries and special tasks as a list 631 of dictionaries. 632 @param queue_entries Host queue entries as a list of dictionaries. 633 634 @return A post-processed list of dictionaries that is to be serialized. 635 """ 636 dict_list = [] 637 for e in interleaved_entries: 638 # Distinguish the two mixed entries based on the existence of 639 # the key "task". If an entry has the key, the entry is for 640 # special task. Otherwise, host queue entry. 641 if 'task' in e: 642 dict_list.append(_special_task_to_dict(e, queue_entries)) 643 else: 644 dict_list.append(_queue_entry_to_dict(e)) 645 return prepare_for_serialization(dict_list) 646 647 648def _compute_next_job_for_tasks(queue_entries, special_tasks): 649 """ 650 For each task, try to figure out the next job that ran after that task. 651 This is done using two pieces of information: 652 * if the task has a queue entry, we can use that entry's job ID. 653 * if the task has a time_started, we can try to compare that against the 654 started_on field of queue_entries. this isn't guaranteed to work perfectly 655 since queue_entries may also have null started_on values. 656 * if the task has neither, or if use of time_started fails, just use the 657 last computed job ID. 658 659 @param queue_entries Host queue entries as a list of dictionaries. 660 @param special_tasks Special tasks as a list of dictionaries. 661 """ 662 next_job_id = None # most recently computed next job 663 hqe_index = 0 # index for scanning by started_on times 664 for task in special_tasks: 665 if task['queue_entry']: 666 next_job_id = task['queue_entry']['job'] 667 elif task['time_started'] is not None: 668 for queue_entry in queue_entries[hqe_index:]: 669 if queue_entry['started_on'] is None: 670 continue 671 t1 = time_utils.time_string_to_datetime( 672 queue_entry['started_on']) 673 t2 = time_utils.time_string_to_datetime(task['time_started']) 674 if t1 < t2: 675 break 676 next_job_id = queue_entry['job']['id'] 677 678 task['next_job_id'] = next_job_id 679 680 # advance hqe_index to just after next_job_id 681 if next_job_id is not None: 682 for queue_entry in queue_entries[hqe_index:]: 683 if queue_entry['job']['id'] < next_job_id: 684 break 685 hqe_index += 1 686 687 688def interleave_entries(queue_entries, special_tasks): 689 """ 690 Both lists should be ordered by descending ID. 691 """ 692 _compute_next_job_for_tasks(queue_entries, special_tasks) 693 694 # start with all special tasks that've run since the last job 695 interleaved_entries = [] 696 for task in special_tasks: 697 if task['next_job_id'] is not None: 698 break 699 interleaved_entries.append(task) 700 701 # now interleave queue entries with the remaining special tasks 702 special_task_index = len(interleaved_entries) 703 for queue_entry in queue_entries: 704 interleaved_entries.append(queue_entry) 705 # add all tasks that ran between this job and the previous one 706 for task in special_tasks[special_task_index:]: 707 if task['next_job_id'] < queue_entry['job']['id']: 708 break 709 interleaved_entries.append(task) 710 special_task_index += 1 711 712 return interleaved_entries 713 714 715def bucket_hosts_by_shard(host_objs, rpc_hostnames=False): 716 """Figure out which hosts are on which shards. 717 718 @param host_objs: A list of host objects. 719 @param rpc_hostnames: If True, the rpc_hostnames of a shard are returned 720 instead of the 'real' shard hostnames. This only matters for testing 721 environments. 722 723 @return: A map of shard hostname: list of hosts on the shard. 724 """ 725 shard_host_map = collections.defaultdict(list) 726 for host in host_objs: 727 if host.shard: 728 shard_name = (host.shard.rpc_hostname() if rpc_hostnames 729 else host.shard.hostname) 730 shard_host_map[shard_name].append(host.hostname) 731 return shard_host_map 732 733 734def create_job_common( 735 name, 736 priority, 737 control_type, 738 control_file=None, 739 hosts=(), 740 meta_hosts=(), 741 one_time_hosts=(), 742 synch_count=None, 743 is_template=False, 744 timeout=None, 745 timeout_mins=None, 746 max_runtime_mins=None, 747 run_verify=True, 748 email_list='', 749 dependencies=(), 750 reboot_before=None, 751 reboot_after=None, 752 parse_failed_repair=None, 753 hostless=False, 754 keyvals=None, 755 drone_set=None, 756 parent_job_id=None, 757 test_retry=0, 758 run_reset=True, 759 require_ssp=None): 760 #pylint: disable-msg=C0111 761 """ 762 Common code between creating "standard" jobs and creating parameterized jobs 763 """ 764 # input validation 765 host_args_passed = any((hosts, meta_hosts, one_time_hosts)) 766 if hostless: 767 if host_args_passed: 768 raise model_logic.ValidationError({ 769 'hostless': 'Hostless jobs cannot include any hosts!'}) 770 if control_type != control_data.CONTROL_TYPE_NAMES.SERVER: 771 raise model_logic.ValidationError({ 772 'control_type': 'Hostless jobs cannot use client-side ' 773 'control files'}) 774 elif not host_args_passed: 775 raise model_logic.ValidationError({ 776 'arguments' : "For host jobs, you must pass at least one of" 777 " 'hosts', 'meta_hosts', 'one_time_hosts'." 778 }) 779 label_objects = list(models.Label.objects.filter(name__in=meta_hosts)) 780 781 # convert hostnames & meta hosts to host/label objects 782 host_objects = models.Host.smart_get_bulk(hosts) 783 _validate_host_job_sharding(host_objects) 784 for host in one_time_hosts: 785 this_host = models.Host.create_one_time_host(host) 786 host_objects.append(this_host) 787 788 metahost_objects = [] 789 meta_host_labels_by_name = {label.name: label for label in label_objects} 790 for label_name in meta_hosts: 791 if label_name in meta_host_labels_by_name: 792 metahost_objects.append(meta_host_labels_by_name[label_name]) 793 else: 794 raise model_logic.ValidationError( 795 {'meta_hosts' : 'Label "%s" not found' % label_name}) 796 797 options = dict(name=name, 798 priority=priority, 799 control_file=control_file, 800 control_type=control_type, 801 is_template=is_template, 802 timeout=timeout, 803 timeout_mins=timeout_mins, 804 max_runtime_mins=max_runtime_mins, 805 synch_count=synch_count, 806 run_verify=run_verify, 807 email_list=email_list, 808 dependencies=dependencies, 809 reboot_before=reboot_before, 810 reboot_after=reboot_after, 811 parse_failed_repair=parse_failed_repair, 812 keyvals=keyvals, 813 drone_set=drone_set, 814 parent_job_id=parent_job_id, 815 test_retry=test_retry, 816 run_reset=run_reset, 817 require_ssp=require_ssp) 818 819 return create_new_job(owner=models.User.current_user().login, 820 options=options, 821 host_objects=host_objects, 822 metahost_objects=metahost_objects) 823 824 825def _validate_host_job_sharding(host_objects): 826 """Check that the hosts obey job sharding rules.""" 827 if not (server_utils.is_shard() 828 or _allowed_hosts_for_master_job(host_objects)): 829 shard_host_map = bucket_hosts_by_shard(host_objects) 830 raise ValueError( 831 'The following hosts are on shard(s), please create ' 832 'seperate jobs for hosts on each shard: %s ' % 833 shard_host_map) 834 835 836def _allowed_hosts_for_master_job(host_objects): 837 """Check that the hosts are allowed for a job on master.""" 838 # We disallow the following jobs on master: 839 # num_shards > 1: this is a job spanning across multiple shards. 840 # num_shards == 1 but number of hosts on shard is less 841 # than total number of hosts: this is a job that spans across 842 # one shard and the master. 843 shard_host_map = bucket_hosts_by_shard(host_objects) 844 num_shards = len(shard_host_map) 845 if num_shards > 1: 846 return False 847 if num_shards == 1: 848 hosts_on_shard = shard_host_map.values()[0] 849 assert len(hosts_on_shard) <= len(host_objects) 850 return len(hosts_on_shard) == len(host_objects) 851 else: 852 return True 853 854 855def encode_ascii(control_file): 856 """Force a control file to only contain ascii characters. 857 858 @param control_file: Control file to encode. 859 860 @returns the control file in an ascii encoding. 861 862 @raises error.ControlFileMalformed: if encoding fails. 863 """ 864 try: 865 return control_file.encode('ascii') 866 except UnicodeDecodeError as e: 867 raise error.ControlFileMalformed(str(e)) 868 869 870def get_wmatrix_url(): 871 """Get wmatrix url from config file. 872 873 @returns the wmatrix url or an empty string. 874 """ 875 return global_config.global_config.get_config_value('AUTOTEST_WEB', 876 'wmatrix_url', 877 default='') 878 879 880def inject_times_to_filter(start_time_key=None, end_time_key=None, 881 start_time_value=None, end_time_value=None, 882 **filter_data): 883 """Inject the key value pairs of start and end time if provided. 884 885 @param start_time_key: A string represents the filter key of start_time. 886 @param end_time_key: A string represents the filter key of end_time. 887 @param start_time_value: Start_time value. 888 @param end_time_value: End_time value. 889 890 @returns the injected filter_data. 891 """ 892 if start_time_value: 893 filter_data[start_time_key] = start_time_value 894 if end_time_value: 895 filter_data[end_time_key] = end_time_value 896 return filter_data 897 898 899def inject_times_to_hqe_special_tasks_filters(filter_data_common, 900 start_time, end_time): 901 """Inject start and end time to hqe and special tasks filters. 902 903 @param filter_data_common: Common filter for hqe and special tasks. 904 @param start_time_key: A string represents the filter key of start_time. 905 @param end_time_key: A string represents the filter key of end_time. 906 907 @returns a pair of hqe and special tasks filters. 908 """ 909 filter_data_special_tasks = filter_data_common.copy() 910 return (inject_times_to_filter('started_on__gte', 'started_on__lte', 911 start_time, end_time, **filter_data_common), 912 inject_times_to_filter('time_started__gte', 'time_started__lte', 913 start_time, end_time, 914 **filter_data_special_tasks)) 915 916 917def retrieve_shard(shard_hostname): 918 """ 919 Retrieves the shard with the given hostname from the database. 920 921 @param shard_hostname: Hostname of the shard to retrieve 922 923 @raises models.Shard.DoesNotExist, if no shard with this hostname was found. 924 925 @returns: Shard object 926 """ 927 return models.Shard.smart_get(shard_hostname) 928 929 930def find_records_for_shard(shard, known_job_ids, known_host_ids): 931 """Find records that should be sent to a shard. 932 933 @param shard: Shard to find records for. 934 @param known_job_ids: List of ids of jobs the shard already has. 935 @param known_host_ids: List of ids of hosts the shard already has. 936 937 @returns: Tuple of lists: 938 (hosts, jobs, suite_job_keyvals, invalid_host_ids) 939 """ 940 hosts, invalid_host_ids = models.Host.assign_to_shard( 941 shard, known_host_ids) 942 jobs = models.Job.assign_to_shard(shard, known_job_ids) 943 parent_job_ids = [job.parent_job_id for job in jobs] 944 suite_job_keyvals = models.JobKeyval.objects.filter( 945 job_id__in=parent_job_ids) 946 return hosts, jobs, suite_job_keyvals, invalid_host_ids 947 948 949def _persist_records_with_type_sent_from_shard( 950 shard, records, record_type, *args, **kwargs): 951 """ 952 Handle records of a specified type that were sent to the shard master. 953 954 @param shard: The shard the records were sent from. 955 @param records: The records sent in their serialized format. 956 @param record_type: Type of the objects represented by records. 957 @param args: Additional arguments that will be passed on to the sanity 958 checks. 959 @param kwargs: Additional arguments that will be passed on to the sanity 960 checks. 961 962 @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail. 963 964 @returns: List of primary keys of the processed records. 965 """ 966 pks = [] 967 for serialized_record in records: 968 pk = serialized_record['id'] 969 try: 970 current_record = record_type.objects.get(pk=pk) 971 except record_type.DoesNotExist: 972 raise error.UnallowedRecordsSentToMaster( 973 'Object with pk %s of type %s does not exist on master.' % ( 974 pk, record_type)) 975 976 try: 977 current_record.sanity_check_update_from_shard( 978 shard, serialized_record, *args, **kwargs) 979 except error.IgnorableUnallowedRecordsSentToMaster: 980 # An illegal record change was attempted, but it was of a non-fatal 981 # variety. Silently skip this record. 982 pass 983 else: 984 current_record.update_from_serialized(serialized_record) 985 pks.append(pk) 986 987 return pks 988 989 990def persist_records_sent_from_shard(shard, jobs, hqes): 991 """ 992 Sanity checking then saving serialized records sent to master from shard. 993 994 During heartbeats shards upload jobs and hostqueuentries. This performs 995 some sanity checks on these and then updates the existing records for those 996 entries with the updated ones from the heartbeat. 997 998 The sanity checks include: 999 - Checking if the objects sent already exist on the master. 1000 - Checking if the objects sent were assigned to this shard. 1001 - hostqueueentries must be sent together with their jobs. 1002 1003 @param shard: The shard the records were sent from. 1004 @param jobs: The jobs the shard sent. 1005 @param hqes: The hostqueuentries the shart sent. 1006 1007 @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail. 1008 """ 1009 job_ids_persisted = _persist_records_with_type_sent_from_shard( 1010 shard, jobs, models.Job) 1011 _persist_records_with_type_sent_from_shard( 1012 shard, hqes, models.HostQueueEntry, 1013 job_ids_sent=job_ids_persisted) 1014 1015 1016def forward_single_host_rpc_to_shard(func): 1017 """This decorator forwards rpc calls that modify a host to a shard. 1018 1019 If a host is assigned to a shard, rpcs that change his attributes should be 1020 forwarded to the shard. 1021 1022 This assumes the first argument of the function represents a host id. 1023 1024 @param func: The function to decorate 1025 1026 @returns: The function to replace func with. 1027 """ 1028 def replacement(**kwargs): 1029 # Only keyword arguments can be accepted here, as we need the argument 1030 # names to send the rpc. serviceHandler always provides arguments with 1031 # their keywords, so this is not a problem. 1032 1033 # A host record (identified by kwargs['id']) can be deleted in 1034 # func(). Therefore, we should save the data that can be needed later 1035 # before func() is called. 1036 shard_hostname = None 1037 host = models.Host.smart_get(kwargs['id']) 1038 if host and host.shard: 1039 shard_hostname = host.shard.rpc_hostname() 1040 ret = func(**kwargs) 1041 if shard_hostname and not server_utils.is_shard(): 1042 run_rpc_on_multiple_hostnames(func.func_name, 1043 [shard_hostname], 1044 **kwargs) 1045 return ret 1046 1047 return replacement 1048 1049 1050def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs): 1051 """Fanout the given rpc to shards of given hosts. 1052 1053 @param host_objs: Host objects for the rpc. 1054 @param rpc_name: The name of the rpc. 1055 @param include_hostnames: If True, include the hostnames in the kwargs. 1056 Hostnames are not always necessary, this functions is designed to 1057 send rpcs to the shard a host is on, the rpcs themselves could be 1058 related to labels, acls etc. 1059 @param kwargs: The kwargs for the rpc. 1060 """ 1061 # Figure out which hosts are on which shards. 1062 shard_host_map = bucket_hosts_by_shard( 1063 host_objs, rpc_hostnames=True) 1064 1065 # Execute the rpc against the appropriate shards. 1066 for shard, hostnames in shard_host_map.iteritems(): 1067 if include_hostnames: 1068 kwargs['hosts'] = hostnames 1069 try: 1070 run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs) 1071 except: 1072 ei = sys.exc_info() 1073 new_exc = error.RPCException('RPC %s failed on shard %s due to ' 1074 '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1])) 1075 raise new_exc.__class__, new_exc, ei[2] 1076 1077 1078def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs): 1079 """Runs an rpc to multiple AFEs 1080 1081 This is i.e. used to propagate changes made to hosts after they are assigned 1082 to a shard. 1083 1084 @param rpc_call: Name of the rpc endpoint to call. 1085 @param shard_hostnames: List of hostnames to run the rpcs on. 1086 @param **kwargs: Keyword arguments to pass in the rpcs. 1087 """ 1088 # Make sure this function is not called on shards but only on master. 1089 assert not server_utils.is_shard() 1090 for shard_hostname in shard_hostnames: 1091 afe = frontend_wrappers.RetryingAFE(server=shard_hostname, 1092 user=thread_local.get_user()) 1093 afe.run(rpc_call, **kwargs) 1094 1095 1096def get_label(name): 1097 """Gets a label object using a given name. 1098 1099 @param name: Label name. 1100 @raises model.Label.DoesNotExist: when there is no label matching 1101 the given name. 1102 @return: a label object matching the given name. 1103 """ 1104 try: 1105 label = models.Label.smart_get(name) 1106 except models.Label.DoesNotExist: 1107 return None 1108 return label 1109 1110 1111# TODO: hide the following rpcs under is_moblab 1112def moblab_only(func): 1113 """Ensure moblab specific functions only run on Moblab devices.""" 1114 def verify(*args, **kwargs): 1115 if not server_utils.is_moblab(): 1116 raise error.RPCException('RPC: %s can only run on Moblab Systems!', 1117 func.__name__) 1118 return func(*args, **kwargs) 1119 return verify 1120 1121 1122def route_rpc_to_master(func): 1123 """Route RPC to master AFE. 1124 1125 When a shard receives an RPC decorated by this, the RPC is just 1126 forwarded to the master. 1127 When the master gets the RPC, the RPC function is executed. 1128 1129 @param func: An RPC function to decorate 1130 1131 @returns: A function replacing the RPC func. 1132 """ 1133 argspec = inspect.getargspec(func) 1134 if argspec.varargs is not None: 1135 raise Exception('RPC function must not have *args.') 1136 1137 @wraps(func) 1138 def replacement(*args, **kwargs): 1139 """We need special handling when decorating an RPC that can be called 1140 directly using positional arguments. 1141 1142 One example is rpc_interface.create_job(). 1143 rpc_interface.create_job_page_handler() calls the function using both 1144 positional and keyword arguments. Since frontend.RpcClient.run() 1145 takes only keyword arguments for an RPC, positional arguments of the 1146 RPC function need to be transformed into keyword arguments. 1147 """ 1148 kwargs = _convert_to_kwargs_only(func, args, kwargs) 1149 if server_utils.is_shard(): 1150 afe = frontend_wrappers.RetryingAFE( 1151 server=server_utils.get_global_afe_hostname(), 1152 user=thread_local.get_user()) 1153 return afe.run(func.func_name, **kwargs) 1154 return func(**kwargs) 1155 1156 return replacement 1157 1158 1159def _convert_to_kwargs_only(func, args, kwargs): 1160 """Convert a function call's arguments to a kwargs dict. 1161 1162 This is best illustrated with an example. Given: 1163 1164 def foo(a, b, **kwargs): 1165 pass 1166 _to_kwargs(foo, (1, 2), {'c': 3}) # corresponding to foo(1, 2, c=3) 1167 1168 foo(**kwargs) 1169 1170 @param func: function whose signature to use 1171 @param args: positional arguments of call 1172 @param kwargs: keyword arguments of call 1173 1174 @returns: kwargs dict 1175 """ 1176 argspec = inspect.getargspec(func) 1177 # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}} 1178 callargs = inspect.getcallargs(func, *args, **kwargs) 1179 if argspec.keywords is None: 1180 kwargs = {} 1181 else: 1182 kwargs = callargs.pop(argspec.keywords) 1183 kwargs.update(callargs) 1184 return kwargs 1185 1186 1187def get_sample_dut(board, pool): 1188 """Get a dut with the given board and pool. 1189 1190 This method is used to help to locate a dut with the given board and pool. 1191 The dut then can be used to identify a devserver in the same subnet. 1192 1193 @param board: Name of the board. 1194 @param pool: Name of the pool. 1195 1196 @return: Name of a dut with the given board and pool. 1197 """ 1198 if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board): 1199 return None 1200 hosts = list(get_host_query( 1201 multiple_labels=('pool:%s' % pool, 'board:%s' % board), 1202 exclude_only_if_needed_labels=False, 1203 valid_only=True, 1204 filter_data={}, 1205 )) 1206 if not hosts: 1207 return None 1208 else: 1209 return hosts[0].hostname 1210