1# pylint: disable=missing-docstring 2""" 3Utility functions for rpc_interface.py. We keep them in a separate file so that 4only RPC interface functions go into that file. 5""" 6 7__author__ = 'showard@google.com (Steve Howard)' 8 9import collections 10import datetime 11from functools import wraps 12import inspect 13import logging 14import os 15import sys 16import django.db.utils 17import django.http 18 19from autotest_lib.frontend import thread_local 20from autotest_lib.frontend.afe import models, model_logic 21from autotest_lib.client.common_lib import control_data, error 22from autotest_lib.client.common_lib import global_config 23from autotest_lib.client.common_lib import time_utils 24from autotest_lib.client.common_lib.cros import dev_server 25from autotest_lib.server import utils as server_utils 26from autotest_lib.server.cros import provision 27from autotest_lib.server.cros.dynamic_suite import frontend_wrappers 28 29NULL_DATETIME = datetime.datetime.max 30NULL_DATE = datetime.date.max 31DUPLICATE_KEY_MSG = 'Duplicate entry' 32RESPECT_STATIC_LABELS = global_config.global_config.get_config_value( 33 'SKYLAB', 'respect_static_labels', type=bool, default=False) 34 35def prepare_for_serialization(objects): 36 """ 37 Prepare Python objects to be returned via RPC. 38 @param objects: objects to be prepared. 39 """ 40 if (isinstance(objects, list) and len(objects) and 41 isinstance(objects[0], dict) and 'id' in objects[0]): 42 objects = _gather_unique_dicts(objects) 43 return _prepare_data(objects) 44 45 46def prepare_rows_as_nested_dicts(query, nested_dict_column_names): 47 """ 48 Prepare a Django query to be returned via RPC as a sequence of nested 49 dictionaries. 50 51 @param query - A Django model query object with a select_related() method. 52 @param nested_dict_column_names - A list of column/attribute names for the 53 rows returned by query to expand into nested dictionaries using 54 their get_object_dict() method when not None. 55 56 @returns An list suitable to returned in an RPC. 57 """ 58 all_dicts = [] 59 for row in query.select_related(): 60 row_dict = row.get_object_dict() 61 for column in nested_dict_column_names: 62 if row_dict[column] is not None: 63 row_dict[column] = getattr(row, column).get_object_dict() 64 all_dicts.append(row_dict) 65 return prepare_for_serialization(all_dicts) 66 67 68def _prepare_data(data): 69 """ 70 Recursively process data structures, performing necessary type 71 conversions to values in data to allow for RPC serialization: 72 -convert datetimes to strings 73 -convert tuples and sets to lists 74 """ 75 if isinstance(data, dict): 76 new_data = {} 77 for key, value in data.iteritems(): 78 new_data[key] = _prepare_data(value) 79 return new_data 80 elif (isinstance(data, list) or isinstance(data, tuple) or 81 isinstance(data, set)): 82 return [_prepare_data(item) for item in data] 83 elif isinstance(data, datetime.date): 84 if data is NULL_DATETIME or data is NULL_DATE: 85 return None 86 return str(data) 87 else: 88 return data 89 90 91def fetchall_as_list_of_dicts(cursor): 92 """ 93 Converts each row in the cursor to a dictionary so that values can be read 94 by using the column name. 95 @param cursor: The database cursor to read from. 96 @returns: A list of each row in the cursor as a dictionary. 97 """ 98 desc = cursor.description 99 return [ dict(zip([col[0] for col in desc], row)) 100 for row in cursor.fetchall() ] 101 102 103def raw_http_response(response_data, content_type=None): 104 response = django.http.HttpResponse(response_data, mimetype=content_type) 105 response['Content-length'] = str(len(response.content)) 106 return response 107 108 109def _gather_unique_dicts(dict_iterable): 110 """\ 111 Pick out unique objects (by ID) from an iterable of object dicts. 112 """ 113 objects = collections.OrderedDict() 114 for obj in dict_iterable: 115 objects.setdefault(obj['id'], obj) 116 return objects.values() 117 118 119def extra_job_status_filters(not_yet_run=False, running=False, finished=False): 120 """\ 121 Generate a SQL WHERE clause for job status filtering, and return it in 122 a dict of keyword args to pass to query.extra(). 123 * not_yet_run: all HQEs are Queued 124 * finished: all HQEs are complete 125 * running: everything else 126 """ 127 if not (not_yet_run or running or finished): 128 return {} 129 not_queued = ('(SELECT job_id FROM afe_host_queue_entries ' 130 'WHERE status != "%s")' 131 % models.HostQueueEntry.Status.QUEUED) 132 not_finished = ('(SELECT job_id FROM afe_host_queue_entries ' 133 'WHERE not complete)') 134 135 where = [] 136 if not_yet_run: 137 where.append('id NOT IN ' + not_queued) 138 if running: 139 where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished)) 140 if finished: 141 where.append('id NOT IN ' + not_finished) 142 return {'where': [' OR '.join(['(%s)' % x for x in where])]} 143 144 145def extra_job_type_filters(extra_args, suite=False, 146 sub=False, standalone=False): 147 """\ 148 Generate a SQL WHERE clause for job status filtering, and return it in 149 a dict of keyword args to pass to query.extra(). 150 151 param extra_args: a dict of existing extra_args. 152 153 No more than one of the parameters should be passed as True: 154 * suite: job which is parent of other jobs 155 * sub: job with a parent job 156 * standalone: job with no child or parent jobs 157 """ 158 assert not ((suite and sub) or 159 (suite and standalone) or 160 (sub and standalone)), ('Cannot specify more than one ' 161 'filter to this function') 162 163 where = extra_args.get('where', []) 164 parent_job_id = ('DISTINCT parent_job_id') 165 child_job_id = ('id') 166 filter_common = ('(SELECT %s FROM afe_jobs ' 167 'WHERE parent_job_id IS NOT NULL)') 168 169 if suite: 170 where.append('id IN ' + filter_common % parent_job_id) 171 elif sub: 172 where.append('id IN ' + filter_common % child_job_id) 173 elif standalone: 174 where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query ' 175 'WHERE parent_job_id IS NOT NULL' 176 ' AND (sub_query.parent_job_id=afe_jobs.id' 177 ' OR sub_query.id=afe_jobs.id))') 178 else: 179 return extra_args 180 181 extra_args['where'] = where 182 return extra_args 183 184 185def get_host_query(multiple_labels, exclude_only_if_needed_labels, 186 valid_only, filter_data): 187 """ 188 @param exclude_only_if_needed_labels: Deprecated. By default it's false. 189 """ 190 if valid_only: 191 initial_query = models.Host.valid_objects.all() 192 else: 193 initial_query = models.Host.objects.all() 194 195 try: 196 hosts = models.Host.get_hosts_with_labels( 197 multiple_labels, initial_query) 198 if not hosts: 199 return hosts 200 201 return models.Host.query_objects(filter_data, initial_query=hosts) 202 except models.Label.DoesNotExist: 203 return models.Host.objects.none() 204 205 206class InconsistencyException(Exception): 207 'Raised when a list of objects does not have a consistent value' 208 209 210def get_consistent_value(objects, field): 211 if not objects: 212 # well a list of nothing is consistent 213 return None 214 215 value = getattr(objects[0], field) 216 for obj in objects: 217 this_value = getattr(obj, field) 218 if this_value != value: 219 raise InconsistencyException(objects[0], obj) 220 return value 221 222 223def afe_test_dict_to_test_object(test_dict): 224 if not isinstance(test_dict, dict): 225 return test_dict 226 227 numerized_dict = {} 228 for key, value in test_dict.iteritems(): 229 try: 230 numerized_dict[key] = int(value) 231 except (ValueError, TypeError): 232 numerized_dict[key] = value 233 234 return type('TestObject', (object,), numerized_dict) 235 236 237def _check_is_server_test(test_type): 238 """Checks if the test type is a server test. 239 240 @param test_type The test type in enum integer or string. 241 242 @returns A boolean to identify if the test type is server test. 243 """ 244 if test_type is not None: 245 if isinstance(test_type, basestring): 246 try: 247 test_type = control_data.CONTROL_TYPE.get_value(test_type) 248 except AttributeError: 249 return False 250 return (test_type == control_data.CONTROL_TYPE.SERVER) 251 return False 252 253 254def prepare_generate_control_file(tests, profilers, db_tests=True): 255 if db_tests: 256 test_objects = [models.Test.smart_get(test) for test in tests] 257 else: 258 test_objects = [afe_test_dict_to_test_object(test) for test in tests] 259 260 profiler_objects = [models.Profiler.smart_get(profiler) 261 for profiler in profilers] 262 # ensure tests are all the same type 263 try: 264 test_type = get_consistent_value(test_objects, 'test_type') 265 except InconsistencyException, exc: 266 test1, test2 = exc.args 267 raise model_logic.ValidationError( 268 {'tests' : 'You cannot run both test_suites and server-side ' 269 'tests together (tests %s and %s differ' % ( 270 test1.name, test2.name)}) 271 272 is_server = _check_is_server_test(test_type) 273 if test_objects: 274 synch_count = max(test.sync_count for test in test_objects) 275 else: 276 synch_count = 1 277 278 if db_tests: 279 dependencies = set(label.name for label 280 in models.Label.objects.filter(test__in=test_objects)) 281 else: 282 dependencies = reduce( 283 set.union, [set(test.dependencies) for test in test_objects]) 284 285 cf_info = dict(is_server=is_server, synch_count=synch_count, 286 dependencies=list(dependencies)) 287 return cf_info, test_objects, profiler_objects 288 289 290def check_job_dependencies(host_objects, job_dependencies): 291 """ 292 Check that a set of machines satisfies a job's dependencies. 293 host_objects: list of models.Host objects 294 job_dependencies: list of names of labels 295 """ 296 # check that hosts satisfy dependencies 297 host_ids = [host.id for host in host_objects] 298 hosts_in_job = models.Host.objects.filter(id__in=host_ids) 299 ok_hosts = hosts_in_job 300 for index, dependency in enumerate(job_dependencies): 301 if not provision.is_for_special_action(dependency): 302 try: 303 label = models.Label.smart_get(dependency) 304 except models.Label.DoesNotExist: 305 logging.info('Label %r does not exist, so it cannot ' 306 'be replaced by static label.', dependency) 307 label = None 308 309 if label is not None and label.is_replaced_by_static(): 310 ok_hosts = ok_hosts.filter(static_labels__name=dependency) 311 else: 312 ok_hosts = ok_hosts.filter(labels__name=dependency) 313 314 failing_hosts = (set(host.hostname for host in host_objects) - 315 set(host.hostname for host in ok_hosts)) 316 if failing_hosts: 317 raise model_logic.ValidationError( 318 {'hosts' : 'Host(s) failed to meet job dependencies (' + 319 (', '.join(job_dependencies)) + '): ' + 320 (', '.join(failing_hosts))}) 321 322 323def check_job_metahost_dependencies(metahost_objects, job_dependencies): 324 """ 325 Check that at least one machine within the metahost spec satisfies the job's 326 dependencies. 327 328 @param metahost_objects A list of label objects representing the metahosts. 329 @param job_dependencies A list of strings of the required label names. 330 @raises NoEligibleHostException If a metahost cannot run the job. 331 """ 332 for metahost in metahost_objects: 333 if metahost.is_replaced_by_static(): 334 static_metahost = models.StaticLabel.smart_get(metahost.name) 335 hosts = models.Host.objects.filter(static_labels=static_metahost) 336 else: 337 hosts = models.Host.objects.filter(labels=metahost) 338 339 for label_name in job_dependencies: 340 if not provision.is_for_special_action(label_name): 341 try: 342 label = models.Label.smart_get(label_name) 343 except models.Label.DoesNotExist: 344 logging.info('Label %r does not exist, so it cannot ' 345 'be replaced by static label.', label_name) 346 label = None 347 348 if label is not None and label.is_replaced_by_static(): 349 hosts = hosts.filter(static_labels__name=label_name) 350 else: 351 hosts = hosts.filter(labels__name=label_name) 352 353 if not any(hosts): 354 raise error.NoEligibleHostException("No hosts within %s satisfy %s." 355 % (metahost.name, ', '.join(job_dependencies))) 356 357 358def _execution_key_for(host_queue_entry): 359 return (host_queue_entry.job.id, host_queue_entry.execution_subdir) 360 361 362def check_abort_synchronous_jobs(host_queue_entries): 363 # ensure user isn't aborting part of a synchronous autoserv execution 364 count_per_execution = {} 365 for queue_entry in host_queue_entries: 366 key = _execution_key_for(queue_entry) 367 count_per_execution.setdefault(key, 0) 368 count_per_execution[key] += 1 369 370 for queue_entry in host_queue_entries: 371 if not queue_entry.execution_subdir: 372 continue 373 execution_count = count_per_execution[_execution_key_for(queue_entry)] 374 if execution_count < queue_entry.job.synch_count: 375 raise model_logic.ValidationError( 376 {'' : 'You cannot abort part of a synchronous job execution ' 377 '(%d/%s), %d included, %d expected' 378 % (queue_entry.job.id, queue_entry.execution_subdir, 379 execution_count, queue_entry.job.synch_count)}) 380 381 382def check_modify_host(update_data): 383 """ 384 Sanity check modify_host* requests. 385 386 @param update_data: A dictionary with the changes to make to a host 387 or hosts. 388 """ 389 # Only the scheduler (monitor_db) is allowed to modify Host status. 390 # Otherwise race conditions happen as a hosts state is changed out from 391 # beneath tasks being run on a host. 392 if 'status' in update_data: 393 raise model_logic.ValidationError({ 394 'status': 'Host status can not be modified by the frontend.'}) 395 396 397def check_modify_host_locking(host, update_data): 398 """ 399 Checks when locking/unlocking has been requested if the host is already 400 locked/unlocked. 401 402 @param host: models.Host object to be modified 403 @param update_data: A dictionary with the changes to make to the host. 404 """ 405 locked = update_data.get('locked', None) 406 lock_reason = update_data.get('lock_reason', None) 407 if locked is not None: 408 if locked and host.locked: 409 raise model_logic.ValidationError({ 410 'locked': 'Host %s already locked by %s on %s.' % 411 (host.hostname, host.locked_by, host.lock_time)}) 412 if not locked and not host.locked: 413 raise model_logic.ValidationError({ 414 'locked': 'Host %s already unlocked.' % host.hostname}) 415 if locked and not lock_reason and not host.locked: 416 raise model_logic.ValidationError({ 417 'locked': 'Please provide a reason for locking Host %s' % 418 host.hostname}) 419 420 421def get_motd(): 422 dirname = os.path.dirname(__file__) 423 filename = os.path.join(dirname, "..", "..", "motd.txt") 424 text = '' 425 try: 426 fp = open(filename, "r") 427 try: 428 text = fp.read() 429 finally: 430 fp.close() 431 except: 432 pass 433 434 return text 435 436 437def _get_metahost_counts(metahost_objects): 438 metahost_counts = {} 439 for metahost in metahost_objects: 440 metahost_counts.setdefault(metahost, 0) 441 metahost_counts[metahost] += 1 442 return metahost_counts 443 444 445def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None): 446 hosts = [] 447 one_time_hosts = [] 448 meta_hosts = [] 449 hostless = False 450 451 queue_entries = job.hostqueueentry_set.all() 452 if queue_entry_filter_data: 453 queue_entries = models.HostQueueEntry.query_objects( 454 queue_entry_filter_data, initial_query=queue_entries) 455 456 for queue_entry in queue_entries: 457 if (queue_entry.host and (preserve_metahosts or 458 not queue_entry.meta_host)): 459 if queue_entry.deleted: 460 continue 461 if queue_entry.host.invalid: 462 one_time_hosts.append(queue_entry.host) 463 else: 464 hosts.append(queue_entry.host) 465 elif queue_entry.meta_host: 466 meta_hosts.append(queue_entry.meta_host) 467 else: 468 hostless = True 469 470 meta_host_counts = _get_metahost_counts(meta_hosts) 471 472 info = dict(dependencies=[label.name for label 473 in job.dependency_labels.all()], 474 hosts=hosts, 475 meta_hosts=meta_hosts, 476 meta_host_counts=meta_host_counts, 477 one_time_hosts=one_time_hosts, 478 hostless=hostless) 479 return info 480 481 482def check_for_duplicate_hosts(host_objects): 483 host_counts = collections.Counter(host_objects) 484 duplicate_hostnames = {host.hostname 485 for host, count in host_counts.iteritems() 486 if count > 1} 487 if duplicate_hostnames: 488 raise model_logic.ValidationError( 489 {'hosts' : 'Duplicate hosts: %s' 490 % ', '.join(duplicate_hostnames)}) 491 492 493def create_new_job(owner, options, host_objects, metahost_objects): 494 all_host_objects = host_objects + metahost_objects 495 dependencies = options.get('dependencies', []) 496 synch_count = options.get('synch_count') 497 498 if synch_count is not None and synch_count > len(all_host_objects): 499 raise model_logic.ValidationError( 500 {'hosts': 501 'only %d hosts provided for job with synch_count = %d' % 502 (len(all_host_objects), synch_count)}) 503 504 check_for_duplicate_hosts(host_objects) 505 506 for label_name in dependencies: 507 if provision.is_for_special_action(label_name): 508 # TODO: We could save a few queries 509 # if we had a bulk ensure-label-exists function, which used 510 # a bulk .get() call. The win is probably very small. 511 _ensure_label_exists(label_name) 512 513 # This only checks targeted hosts, not hosts eligible due to the metahost 514 check_job_dependencies(host_objects, dependencies) 515 check_job_metahost_dependencies(metahost_objects, dependencies) 516 517 options['dependencies'] = list( 518 models.Label.objects.filter(name__in=dependencies)) 519 520 job = models.Job.create(owner=owner, options=options, 521 hosts=all_host_objects) 522 job.queue(all_host_objects, 523 is_template=options.get('is_template', False)) 524 return job.id 525 526 527def _ensure_label_exists(name): 528 """ 529 Ensure that a label called |name| exists in the Django models. 530 531 This function is to be called from within afe rpcs only, as an 532 alternative to server.cros.provision.ensure_label_exists(...). It works 533 by Django model manipulation, rather than by making another create_label 534 rpc call. 535 536 @param name: the label to check for/create. 537 @raises ValidationError: There was an error in the response that was 538 not because the label already existed. 539 @returns True is a label was created, False otherwise. 540 """ 541 # Make sure this function is not called on shards but only on main. 542 assert not server_utils.is_shard() 543 try: 544 models.Label.objects.get(name=name) 545 except models.Label.DoesNotExist: 546 try: 547 new_label = models.Label.objects.create(name=name) 548 new_label.save() 549 return True 550 except django.db.utils.IntegrityError as e: 551 # It is possible that another suite/test already 552 # created the label between the check and save. 553 if DUPLICATE_KEY_MSG in str(e): 554 return False 555 else: 556 raise 557 return False 558 559 560def find_platform(hostname, label_list): 561 """ 562 Figure out the platform name for the given host 563 object. If none, the return value for either will be None. 564 565 @param hostname: The hostname to find platform. 566 @param label_list: The label list to find platform. 567 568 @returns platform name for the given host. 569 """ 570 platforms = [label.name for label in label_list if label.platform] 571 if not platforms: 572 platform = None 573 else: 574 platform = platforms[0] 575 576 if len(platforms) > 1: 577 raise ValueError('Host %s has more than one platform: %s' % 578 (hostname, ', '.join(platforms))) 579 580 return platform 581 582 583# support for get_host_queue_entries_and_special_tasks() 584 585def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on): 586 return dict(type=type, 587 host=entry['host'], 588 job=job_dict, 589 execution_path=exec_path, 590 status=status, 591 started_on=started_on, 592 id=str(entry['id']) + type, 593 oid=entry['id']) 594 595 596def _special_task_to_dict(task, queue_entries): 597 """Transforms a special task dictionary to another form of dictionary. 598 599 @param task Special task as a dictionary type 600 @param queue_entries Host queue entries as a list of dictionaries. 601 602 @return Transformed dictionary for a special task. 603 """ 604 job_dict = None 605 if task['queue_entry']: 606 # Scan queue_entries to get the job detail info. 607 for qentry in queue_entries: 608 if task['queue_entry']['id'] == qentry['id']: 609 job_dict = qentry['job'] 610 break 611 # If not found, get it from DB. 612 if job_dict is None: 613 job = models.Job.objects.get(id=task['queue_entry']['job']) 614 job_dict = job.get_object_dict() 615 616 exec_path = server_utils.get_special_task_exec_path( 617 task['host']['hostname'], task['id'], task['task'], 618 time_utils.time_string_to_datetime(task['time_requested'])) 619 status = server_utils.get_special_task_status( 620 task['is_complete'], task['success'], task['is_active']) 621 return _common_entry_to_dict(task, task['task'], job_dict, 622 exec_path, status, task['time_started']) 623 624 625def _queue_entry_to_dict(queue_entry): 626 job_dict = queue_entry['job'] 627 tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner']) 628 exec_path = server_utils.get_hqe_exec_path(tag, 629 queue_entry['execution_subdir']) 630 return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path, 631 queue_entry['status'], queue_entry['started_on']) 632 633 634def prepare_host_queue_entries_and_special_tasks(interleaved_entries, 635 queue_entries): 636 """ 637 Prepare for serialization the interleaved entries of host queue entries 638 and special tasks. 639 Each element in the entries is a dictionary type. 640 The special task dictionary has only a job id for a job and lacks 641 the detail of the job while the host queue entry dictionary has. 642 queue_entries is used to look up the job detail info. 643 644 @param interleaved_entries Host queue entries and special tasks as a list 645 of dictionaries. 646 @param queue_entries Host queue entries as a list of dictionaries. 647 648 @return A post-processed list of dictionaries that is to be serialized. 649 """ 650 dict_list = [] 651 for e in interleaved_entries: 652 # Distinguish the two mixed entries based on the existence of 653 # the key "task". If an entry has the key, the entry is for 654 # special task. Otherwise, host queue entry. 655 if 'task' in e: 656 dict_list.append(_special_task_to_dict(e, queue_entries)) 657 else: 658 dict_list.append(_queue_entry_to_dict(e)) 659 return prepare_for_serialization(dict_list) 660 661 662def _compute_next_job_for_tasks(queue_entries, special_tasks): 663 """ 664 For each task, try to figure out the next job that ran after that task. 665 This is done using two pieces of information: 666 * if the task has a queue entry, we can use that entry's job ID. 667 * if the task has a time_started, we can try to compare that against the 668 started_on field of queue_entries. this isn't guaranteed to work perfectly 669 since queue_entries may also have null started_on values. 670 * if the task has neither, or if use of time_started fails, just use the 671 last computed job ID. 672 673 @param queue_entries Host queue entries as a list of dictionaries. 674 @param special_tasks Special tasks as a list of dictionaries. 675 """ 676 next_job_id = None # most recently computed next job 677 hqe_index = 0 # index for scanning by started_on times 678 for task in special_tasks: 679 if task['queue_entry']: 680 next_job_id = task['queue_entry']['job'] 681 elif task['time_started'] is not None: 682 for queue_entry in queue_entries[hqe_index:]: 683 if queue_entry['started_on'] is None: 684 continue 685 t1 = time_utils.time_string_to_datetime( 686 queue_entry['started_on']) 687 t2 = time_utils.time_string_to_datetime(task['time_started']) 688 if t1 < t2: 689 break 690 next_job_id = queue_entry['job']['id'] 691 692 task['next_job_id'] = next_job_id 693 694 # advance hqe_index to just after next_job_id 695 if next_job_id is not None: 696 for queue_entry in queue_entries[hqe_index:]: 697 if queue_entry['job']['id'] < next_job_id: 698 break 699 hqe_index += 1 700 701 702def interleave_entries(queue_entries, special_tasks): 703 """ 704 Both lists should be ordered by descending ID. 705 """ 706 _compute_next_job_for_tasks(queue_entries, special_tasks) 707 708 # start with all special tasks that've run since the last job 709 interleaved_entries = [] 710 for task in special_tasks: 711 if task['next_job_id'] is not None: 712 break 713 interleaved_entries.append(task) 714 715 # now interleave queue entries with the remaining special tasks 716 special_task_index = len(interleaved_entries) 717 for queue_entry in queue_entries: 718 interleaved_entries.append(queue_entry) 719 # add all tasks that ran between this job and the previous one 720 for task in special_tasks[special_task_index:]: 721 if task['next_job_id'] < queue_entry['job']['id']: 722 break 723 interleaved_entries.append(task) 724 special_task_index += 1 725 726 return interleaved_entries 727 728 729def bucket_hosts_by_shard(host_objs): 730 """Figure out which hosts are on which shards. 731 732 @param host_objs: A list of host objects. 733 734 @return: A map of shard hostname: list of hosts on the shard. 735 """ 736 shard_host_map = collections.defaultdict(list) 737 for host in host_objs: 738 if host.shard: 739 shard_host_map[host.shard.hostname].append(host.hostname) 740 return shard_host_map 741 742 743def create_job_common( 744 name, 745 priority, 746 control_type, 747 control_file=None, 748 hosts=(), 749 meta_hosts=(), 750 one_time_hosts=(), 751 synch_count=None, 752 is_template=False, 753 timeout=None, 754 timeout_mins=None, 755 max_runtime_mins=None, 756 run_verify=True, 757 email_list='', 758 dependencies=(), 759 reboot_before=None, 760 reboot_after=None, 761 parse_failed_repair=None, 762 hostless=False, 763 keyvals=None, 764 drone_set=None, 765 parent_job_id=None, 766 run_reset=True, 767 require_ssp=None): 768 #pylint: disable-msg=C0111 769 """ 770 Common code between creating "standard" jobs and creating parameterized jobs 771 """ 772 # input validation 773 host_args_passed = any((hosts, meta_hosts, one_time_hosts)) 774 if hostless: 775 if host_args_passed: 776 raise model_logic.ValidationError({ 777 'hostless': 'Hostless jobs cannot include any hosts!'}) 778 if control_type != control_data.CONTROL_TYPE_NAMES.SERVER: 779 raise model_logic.ValidationError({ 780 'control_type': 'Hostless jobs cannot use client-side ' 781 'control files'}) 782 elif not host_args_passed: 783 raise model_logic.ValidationError({ 784 'arguments' : "For host jobs, you must pass at least one of" 785 " 'hosts', 'meta_hosts', 'one_time_hosts'." 786 }) 787 label_objects = list(models.Label.objects.filter(name__in=meta_hosts)) 788 789 # convert hostnames & meta hosts to host/label objects 790 host_objects = models.Host.smart_get_bulk(hosts) 791 _validate_host_job_sharding(host_objects) 792 for host in one_time_hosts: 793 this_host = models.Host.create_one_time_host(host) 794 host_objects.append(this_host) 795 796 metahost_objects = [] 797 meta_host_labels_by_name = {label.name: label for label in label_objects} 798 for label_name in meta_hosts: 799 if label_name in meta_host_labels_by_name: 800 metahost_objects.append(meta_host_labels_by_name[label_name]) 801 else: 802 raise model_logic.ValidationError( 803 {'meta_hosts' : 'Label "%s" not found' % label_name}) 804 805 options = dict(name=name, 806 priority=priority, 807 control_file=control_file, 808 control_type=control_type, 809 is_template=is_template, 810 timeout=timeout, 811 timeout_mins=timeout_mins, 812 max_runtime_mins=max_runtime_mins, 813 synch_count=synch_count, 814 run_verify=run_verify, 815 email_list=email_list, 816 dependencies=dependencies, 817 reboot_before=reboot_before, 818 reboot_after=reboot_after, 819 parse_failed_repair=parse_failed_repair, 820 keyvals=keyvals, 821 drone_set=drone_set, 822 parent_job_id=parent_job_id, 823 # TODO(crbug.com/873716) DEPRECATED. Remove entirely. 824 test_retry=0, 825 run_reset=run_reset, 826 require_ssp=require_ssp) 827 828 return create_new_job(owner=models.User.current_user().login, 829 options=options, 830 host_objects=host_objects, 831 metahost_objects=metahost_objects) 832 833 834def _validate_host_job_sharding(host_objects): 835 """Check that the hosts obey job sharding rules.""" 836 if not (server_utils.is_shard() 837 or _allowed_hosts_for_main_job(host_objects)): 838 shard_host_map = bucket_hosts_by_shard(host_objects) 839 raise ValueError( 840 'The following hosts are on shard(s), please create ' 841 'seperate jobs for hosts on each shard: %s ' % 842 shard_host_map) 843 844 845def _allowed_hosts_for_main_job(host_objects): 846 """Check that the hosts are allowed for a job on main.""" 847 # We disallow the following jobs on main: 848 # num_shards > 1: this is a job spanning across multiple shards. 849 # num_shards == 1 but number of hosts on shard is less 850 # than total number of hosts: this is a job that spans across 851 # one shard and the main. 852 shard_host_map = bucket_hosts_by_shard(host_objects) 853 num_shards = len(shard_host_map) 854 if num_shards > 1: 855 return False 856 if num_shards == 1: 857 hosts_on_shard = shard_host_map.values()[0] 858 assert len(hosts_on_shard) <= len(host_objects) 859 return len(hosts_on_shard) == len(host_objects) 860 else: 861 return True 862 863 864def encode_ascii(control_file): 865 """Force a control file to only contain ascii characters. 866 867 @param control_file: Control file to encode. 868 869 @returns the control file in an ascii encoding. 870 871 @raises error.ControlFileMalformed: if encoding fails. 872 """ 873 try: 874 return control_file.encode('ascii') 875 except UnicodeDecodeError as e: 876 raise error.ControlFileMalformed(str(e)) 877 878 879def get_wmatrix_url(): 880 """Get wmatrix url from config file. 881 882 @returns the wmatrix url or an empty string. 883 """ 884 return global_config.global_config.get_config_value('AUTOTEST_WEB', 885 'wmatrix_url', 886 default='') 887 888 889def get_stainless_url(): 890 """Get stainless url from config file. 891 892 @returns the stainless url or an empty string. 893 """ 894 return global_config.global_config.get_config_value('AUTOTEST_WEB', 895 'stainless_url', 896 default='') 897 898 899def inject_times_to_filter(start_time_key=None, end_time_key=None, 900 start_time_value=None, end_time_value=None, 901 **filter_data): 902 """Inject the key value pairs of start and end time if provided. 903 904 @param start_time_key: A string represents the filter key of start_time. 905 @param end_time_key: A string represents the filter key of end_time. 906 @param start_time_value: Start_time value. 907 @param end_time_value: End_time value. 908 909 @returns the injected filter_data. 910 """ 911 if start_time_value: 912 filter_data[start_time_key] = start_time_value 913 if end_time_value: 914 filter_data[end_time_key] = end_time_value 915 return filter_data 916 917 918def inject_times_to_hqe_special_tasks_filters(filter_data_common, 919 start_time, end_time): 920 """Inject start and end time to hqe and special tasks filters. 921 922 @param filter_data_common: Common filter for hqe and special tasks. 923 @param start_time_key: A string represents the filter key of start_time. 924 @param end_time_key: A string represents the filter key of end_time. 925 926 @returns a pair of hqe and special tasks filters. 927 """ 928 filter_data_special_tasks = filter_data_common.copy() 929 return (inject_times_to_filter('started_on__gte', 'started_on__lte', 930 start_time, end_time, **filter_data_common), 931 inject_times_to_filter('time_started__gte', 'time_started__lte', 932 start_time, end_time, 933 **filter_data_special_tasks)) 934 935 936def retrieve_shard(shard_hostname): 937 """ 938 Retrieves the shard with the given hostname from the database. 939 940 @param shard_hostname: Hostname of the shard to retrieve 941 942 @raises models.Shard.DoesNotExist, if no shard with this hostname was found. 943 944 @returns: Shard object 945 """ 946 return models.Shard.smart_get(shard_hostname) 947 948 949def find_records_for_shard(shard, known_job_ids, known_host_ids): 950 """Find records that should be sent to a shard. 951 952 @param shard: Shard to find records for. 953 @param known_job_ids: List of ids of jobs the shard already has. 954 @param known_host_ids: List of ids of hosts the shard already has. 955 956 @returns: Tuple of lists: 957 (hosts, jobs, suite_job_keyvals, invalid_host_ids) 958 """ 959 hosts, invalid_host_ids = models.Host.assign_to_shard( 960 shard, known_host_ids) 961 jobs = models.Job.assign_to_shard(shard, known_job_ids) 962 parent_job_ids = [job.parent_job_id for job in jobs] 963 suite_job_keyvals = models.JobKeyval.objects.filter( 964 job_id__in=parent_job_ids) 965 return hosts, jobs, suite_job_keyvals, invalid_host_ids 966 967 968def _persist_records_with_type_sent_from_shard( 969 shard, records, record_type, *args, **kwargs): 970 """ 971 Handle records of a specified type that were sent to the shard main. 972 973 @param shard: The shard the records were sent from. 974 @param records: The records sent in their serialized format. 975 @param record_type: Type of the objects represented by records. 976 @param args: Additional arguments that will be passed on to the sanity 977 checks. 978 @param kwargs: Additional arguments that will be passed on to the sanity 979 checks. 980 981 @raises error.UnallowedRecordsSentToMain if any of the sanity checks fail. 982 983 @returns: List of primary keys of the processed records. 984 """ 985 pks = [] 986 for serialized_record in records: 987 pk = serialized_record['id'] 988 try: 989 current_record = record_type.objects.get(pk=pk) 990 except record_type.DoesNotExist: 991 raise error.UnallowedRecordsSentToMain( 992 'Object with pk %s of type %s does not exist on main.' % ( 993 pk, record_type)) 994 995 try: 996 current_record.sanity_check_update_from_shard( 997 shard, serialized_record, *args, **kwargs) 998 except error.IgnorableUnallowedRecordsSentToMain: 999 # An illegal record change was attempted, but it was of a non-fatal 1000 # variety. Silently skip this record. 1001 pass 1002 else: 1003 current_record.update_from_serialized(serialized_record) 1004 pks.append(pk) 1005 1006 return pks 1007 1008 1009def persist_records_sent_from_shard(shard, jobs, hqes): 1010 """ 1011 Sanity checking then saving serialized records sent to main from shard. 1012 1013 During heartbeats shards upload jobs and hostqueuentries. This performs 1014 some sanity checks on these and then updates the existing records for those 1015 entries with the updated ones from the heartbeat. 1016 1017 The sanity checks include: 1018 - Checking if the objects sent already exist on the main. 1019 - Checking if the objects sent were assigned to this shard. 1020 - hostqueueentries must be sent together with their jobs. 1021 1022 @param shard: The shard the records were sent from. 1023 @param jobs: The jobs the shard sent. 1024 @param hqes: The hostqueuentries the shart sent. 1025 1026 @raises error.UnallowedRecordsSentToMain if any of the sanity checks fail. 1027 """ 1028 job_ids_persisted = _persist_records_with_type_sent_from_shard( 1029 shard, jobs, models.Job) 1030 _persist_records_with_type_sent_from_shard( 1031 shard, hqes, models.HostQueueEntry, 1032 job_ids_sent=job_ids_persisted) 1033 1034 1035def forward_single_host_rpc_to_shard(func): 1036 """This decorator forwards rpc calls that modify a host to a shard. 1037 1038 If a host is assigned to a shard, rpcs that change his attributes should be 1039 forwarded to the shard. 1040 1041 This assumes the first argument of the function represents a host id. 1042 1043 @param func: The function to decorate 1044 1045 @returns: The function to replace func with. 1046 """ 1047 def replacement(**kwargs): 1048 # Only keyword arguments can be accepted here, as we need the argument 1049 # names to send the rpc. serviceHandler always provides arguments with 1050 # their keywords, so this is not a problem. 1051 1052 # A host record (identified by kwargs['id']) can be deleted in 1053 # func(). Therefore, we should save the data that can be needed later 1054 # before func() is called. 1055 shard_hostname = None 1056 host = models.Host.smart_get(kwargs['id']) 1057 if host and host.shard: 1058 shard_hostname = host.shard.hostname 1059 ret = func(**kwargs) 1060 if shard_hostname and not server_utils.is_shard(): 1061 run_rpc_on_multiple_hostnames(func.func_name, 1062 [shard_hostname], 1063 **kwargs) 1064 return ret 1065 1066 return replacement 1067 1068 1069def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs): 1070 """Fanout the given rpc to shards of given hosts. 1071 1072 @param host_objs: Host objects for the rpc. 1073 @param rpc_name: The name of the rpc. 1074 @param include_hostnames: If True, include the hostnames in the kwargs. 1075 Hostnames are not always necessary, this functions is designed to 1076 send rpcs to the shard a host is on, the rpcs themselves could be 1077 related to labels, acls etc. 1078 @param kwargs: The kwargs for the rpc. 1079 """ 1080 # Figure out which hosts are on which shards. 1081 shard_host_map = bucket_hosts_by_shard(host_objs) 1082 1083 # Execute the rpc against the appropriate shards. 1084 for shard, hostnames in shard_host_map.iteritems(): 1085 if include_hostnames: 1086 kwargs['hosts'] = hostnames 1087 try: 1088 run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs) 1089 except: 1090 ei = sys.exc_info() 1091 new_exc = error.RPCException('RPC %s failed on shard %s due to ' 1092 '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1])) 1093 raise new_exc.__class__, new_exc, ei[2] 1094 1095 1096def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs): 1097 """Runs an rpc to multiple AFEs 1098 1099 This is i.e. used to propagate changes made to hosts after they are assigned 1100 to a shard. 1101 1102 @param rpc_call: Name of the rpc endpoint to call. 1103 @param shard_hostnames: List of hostnames to run the rpcs on. 1104 @param **kwargs: Keyword arguments to pass in the rpcs. 1105 """ 1106 # Make sure this function is not called on shards but only on main. 1107 assert not server_utils.is_shard() 1108 for shard_hostname in shard_hostnames: 1109 afe = frontend_wrappers.RetryingAFE(server=shard_hostname, 1110 user=thread_local.get_user()) 1111 afe.run(rpc_call, **kwargs) 1112 1113 1114def get_label(name): 1115 """Gets a label object using a given name. 1116 1117 @param name: Label name. 1118 @raises model.Label.DoesNotExist: when there is no label matching 1119 the given name. 1120 @return: a label object matching the given name. 1121 """ 1122 try: 1123 label = models.Label.smart_get(name) 1124 except models.Label.DoesNotExist: 1125 return None 1126 return label 1127 1128 1129# TODO: hide the following rpcs under is_moblab 1130def moblab_only(func): 1131 """Ensure moblab specific functions only run on Moblab devices.""" 1132 def verify(*args, **kwargs): 1133 if not server_utils.is_moblab(): 1134 raise error.RPCException('RPC: %s can only run on Moblab Systems!', 1135 func.__name__) 1136 return func(*args, **kwargs) 1137 return verify 1138 1139 1140def route_rpc_to_main(func): 1141 """Route RPC to main AFE. 1142 1143 When a shard receives an RPC decorated by this, the RPC is just 1144 forwarded to the main. 1145 When the main gets the RPC, the RPC function is executed. 1146 1147 @param func: An RPC function to decorate 1148 1149 @returns: A function replacing the RPC func. 1150 """ 1151 argspec = inspect.getargspec(func) 1152 if argspec.varargs is not None: 1153 raise Exception('RPC function must not have *args.') 1154 1155 @wraps(func) 1156 def replacement(*args, **kwargs): 1157 """We need special handling when decorating an RPC that can be called 1158 directly using positional arguments. 1159 1160 One example is rpc_interface.create_job(). 1161 rpc_interface.create_job_page_handler() calls the function using both 1162 positional and keyword arguments. Since frontend.RpcClient.run() 1163 takes only keyword arguments for an RPC, positional arguments of the 1164 RPC function need to be transformed into keyword arguments. 1165 """ 1166 kwargs = _convert_to_kwargs_only(func, args, kwargs) 1167 if server_utils.is_shard(): 1168 afe = frontend_wrappers.RetryingAFE( 1169 server=server_utils.get_global_afe_hostname(), 1170 user=thread_local.get_user()) 1171 return afe.run(func.func_name, **kwargs) 1172 return func(**kwargs) 1173 1174 return replacement 1175 1176 1177def _convert_to_kwargs_only(func, args, kwargs): 1178 """Convert a function call's arguments to a kwargs dict. 1179 1180 This is best illustrated with an example. Given: 1181 1182 def foo(a, b, **kwargs): 1183 pass 1184 _to_kwargs(foo, (1, 2), {'c': 3}) # corresponding to foo(1, 2, c=3) 1185 1186 foo(**kwargs) 1187 1188 @param func: function whose signature to use 1189 @param args: positional arguments of call 1190 @param kwargs: keyword arguments of call 1191 1192 @returns: kwargs dict 1193 """ 1194 argspec = inspect.getargspec(func) 1195 # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}} 1196 callargs = inspect.getcallargs(func, *args, **kwargs) 1197 if argspec.keywords is None: 1198 kwargs = {} 1199 else: 1200 kwargs = callargs.pop(argspec.keywords) 1201 kwargs.update(callargs) 1202 return kwargs 1203 1204 1205def get_sample_dut(board, pool): 1206 """Get a dut with the given board and pool. 1207 1208 This method is used to help to locate a dut with the given board and pool. 1209 The dut then can be used to identify a devserver in the same subnet. 1210 1211 @param board: Name of the board. 1212 @param pool: Name of the pool. 1213 1214 @return: Name of a dut with the given board and pool. 1215 """ 1216 if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board): 1217 return None 1218 1219 hosts = list(get_host_query( 1220 multiple_labels=('pool:%s' % pool, 'board:%s' % board), 1221 exclude_only_if_needed_labels=False, 1222 valid_only=True, 1223 filter_data={}, 1224 )) 1225 if not hosts: 1226 return None 1227 else: 1228 return hosts[0].hostname 1229