1# pylint: disable=missing-docstring 2""" 3Utility functions for rpc_interface.py. We keep them in a separate file so that 4only RPC interface functions go into that file. 5""" 6 7__author__ = 'showard@google.com (Steve Howard)' 8 9import collections 10import datetime 11import inspect 12import logging 13import os 14from functools import wraps 15 16import django.db.utils 17import django.http 18import six 19from autotest_lib.client.common_lib import (control_data, error, global_config, 20 time_utils) 21from autotest_lib.client.common_lib.cros import dev_server 22from autotest_lib.frontend import thread_local 23from autotest_lib.frontend.afe import model_logic, models 24from autotest_lib.server import utils as server_utils 25from autotest_lib.server.cros import provision 26from autotest_lib.server.cros.dynamic_suite import frontend_wrappers 27 28NULL_DATETIME = datetime.datetime.max 29NULL_DATE = datetime.date.max 30DUPLICATE_KEY_MSG = 'Duplicate entry' 31RESPECT_STATIC_LABELS = global_config.global_config.get_config_value( 32 'SKYLAB', 'respect_static_labels', type=bool, default=False) 33 34def prepare_for_serialization(objects): 35 """ 36 Prepare Python objects to be returned via RPC. 37 @param objects: objects to be prepared. 38 """ 39 if (isinstance(objects, list) and len(objects) and 40 isinstance(objects[0], dict) and 'id' in objects[0]): 41 objects = _gather_unique_dicts(objects) 42 return _prepare_data(objects) 43 44 45def prepare_rows_as_nested_dicts(query, nested_dict_column_names): 46 """ 47 Prepare a Django query to be returned via RPC as a sequence of nested 48 dictionaries. 49 50 @param query - A Django model query object with a select_related() method. 51 @param nested_dict_column_names - A list of column/attribute names for the 52 rows returned by query to expand into nested dictionaries using 53 their get_object_dict() method when not None. 54 55 @returns An list suitable to returned in an RPC. 
56 """ 57 all_dicts = [] 58 for row in query.select_related(): 59 row_dict = row.get_object_dict() 60 for column in nested_dict_column_names: 61 if row_dict[column] is not None: 62 row_dict[column] = getattr(row, column).get_object_dict() 63 all_dicts.append(row_dict) 64 return prepare_for_serialization(all_dicts) 65 66 67def _prepare_data(data): 68 """ 69 Recursively process data structures, performing necessary type 70 conversions to values in data to allow for RPC serialization: 71 -convert datetimes to strings 72 -convert tuples and sets to lists 73 """ 74 if isinstance(data, dict): 75 new_data = {} 76 for key, value in six.iteritems(data): 77 new_data[key] = _prepare_data(value) 78 return new_data 79 elif (isinstance(data, list) or isinstance(data, tuple) or 80 isinstance(data, set)): 81 return [_prepare_data(item) for item in data] 82 elif isinstance(data, datetime.date): 83 if data is NULL_DATETIME or data is NULL_DATE: 84 return None 85 return str(data) 86 else: 87 return data 88 89 90def fetchall_as_list_of_dicts(cursor): 91 """ 92 Converts each row in the cursor to a dictionary so that values can be read 93 by using the column name. 94 @param cursor: The database cursor to read from. 95 @returns: A list of each row in the cursor as a dictionary. 96 """ 97 desc = cursor.description 98 return [ dict(zip([col[0] for col in desc], row)) 99 for row in cursor.fetchall() ] 100 101 102def raw_http_response(response_data, content_type=None): 103 response = django.http.HttpResponse(response_data, mimetype=content_type) 104 response['Content-length'] = str(len(response.content)) 105 return response 106 107 108def _gather_unique_dicts(dict_iterable): 109 """\ 110 Pick out unique objects (by ID) from an iterable of object dicts. 
111 """ 112 objects = collections.OrderedDict() 113 for obj in dict_iterable: 114 objects.setdefault(obj['id'], obj) 115 return list(objects.values()) 116 117 118def extra_job_status_filters(not_yet_run=False, running=False, finished=False): 119 """\ 120 Generate a SQL WHERE clause for job status filtering, and return it in 121 a dict of keyword args to pass to query.extra(). 122 * not_yet_run: all HQEs are Queued 123 * finished: all HQEs are complete 124 * running: everything else 125 """ 126 if not (not_yet_run or running or finished): 127 return {} 128 not_queued = ('(SELECT job_id FROM afe_host_queue_entries ' 129 'WHERE status != "%s")' 130 % models.HostQueueEntry.Status.QUEUED) 131 not_finished = ('(SELECT job_id FROM afe_host_queue_entries ' 132 'WHERE not complete)') 133 134 where = [] 135 if not_yet_run: 136 where.append('id NOT IN ' + not_queued) 137 if running: 138 where.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished)) 139 if finished: 140 where.append('id NOT IN ' + not_finished) 141 return {'where': [' OR '.join(['(%s)' % x for x in where])]} 142 143 144def extra_job_type_filters(extra_args, suite=False, 145 sub=False, standalone=False): 146 """\ 147 Generate a SQL WHERE clause for job status filtering, and return it in 148 a dict of keyword args to pass to query.extra(). 149 150 param extra_args: a dict of existing extra_args. 
151 152 No more than one of the parameters should be passed as True: 153 * suite: job which is parent of other jobs 154 * sub: job with a parent job 155 * standalone: job with no child or parent jobs 156 """ 157 assert not ((suite and sub) or 158 (suite and standalone) or 159 (sub and standalone)), ('Cannot specify more than one ' 160 'filter to this function') 161 162 where = extra_args.get('where', []) 163 parent_job_id = ('DISTINCT parent_job_id') 164 child_job_id = ('id') 165 filter_common = ('(SELECT %s FROM afe_jobs ' 166 'WHERE parent_job_id IS NOT NULL)') 167 168 if suite: 169 where.append('id IN ' + filter_common % parent_job_id) 170 elif sub: 171 where.append('id IN ' + filter_common % child_job_id) 172 elif standalone: 173 where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query ' 174 'WHERE parent_job_id IS NOT NULL' 175 ' AND (sub_query.parent_job_id=afe_jobs.id' 176 ' OR sub_query.id=afe_jobs.id))') 177 else: 178 return extra_args 179 180 extra_args['where'] = where 181 return extra_args 182 183 184def get_host_query(multiple_labels, exclude_only_if_needed_labels, 185 valid_only, filter_data): 186 """ 187 @param exclude_only_if_needed_labels: Deprecated. By default it's false. 
188 """ 189 if valid_only: 190 initial_query = models.Host.valid_objects.all() 191 else: 192 initial_query = models.Host.objects.all() 193 194 try: 195 hosts = models.Host.get_hosts_with_labels( 196 multiple_labels, initial_query) 197 if not hosts: 198 return hosts 199 200 return models.Host.query_objects(filter_data, initial_query=hosts) 201 except models.Label.DoesNotExist: 202 return models.Host.objects.none() 203 204 205class InconsistencyException(Exception): 206 'Raised when a list of objects does not have a consistent value' 207 208 209def get_consistent_value(objects, field): 210 if not objects: 211 # well a list of nothing is consistent 212 return None 213 214 value = getattr(objects[0], field) 215 for obj in objects: 216 this_value = getattr(obj, field) 217 if this_value != value: 218 raise InconsistencyException(objects[0], obj) 219 return value 220 221 222def afe_test_dict_to_test_object(test_dict): 223 if not isinstance(test_dict, dict): 224 return test_dict 225 226 numerized_dict = {} 227 for key, value in six.iteritems(test_dict): 228 try: 229 numerized_dict[key] = int(value) 230 except (ValueError, TypeError): 231 numerized_dict[key] = value 232 233 return type('TestObject', (object,), numerized_dict) 234 235 236def _check_is_server_test(test_type): 237 """Checks if the test type is a server test. 238 239 @param test_type The test type in enum integer or string. 240 241 @returns A boolean to identify if the test type is server test. 
242 """ 243 if test_type is not None: 244 if isinstance(test_type, six.string_types): 245 try: 246 test_type = control_data.CONTROL_TYPE.get_value(test_type) 247 except AttributeError: 248 return False 249 return (test_type == control_data.CONTROL_TYPE.SERVER) 250 return False 251 252 253def prepare_generate_control_file(tests, profilers, db_tests=True): 254 if db_tests: 255 test_objects = [models.Test.smart_get(test) for test in tests] 256 else: 257 test_objects = [afe_test_dict_to_test_object(test) for test in tests] 258 259 profiler_objects = [models.Profiler.smart_get(profiler) 260 for profiler in profilers] 261 # ensure tests are all the same type 262 try: 263 test_type = get_consistent_value(test_objects, 'test_type') 264 except InconsistencyException as exc: 265 test1, test2 = exc.args 266 raise model_logic.ValidationError( 267 {'tests' : 'You cannot run both test_suites and server-side ' 268 'tests together (tests %s and %s differ' % ( 269 test1.name, test2.name)}) 270 271 is_server = _check_is_server_test(test_type) 272 if test_objects: 273 synch_count = max(test.sync_count for test in test_objects) 274 else: 275 synch_count = 1 276 277 if db_tests: 278 dependencies = set(label.name for label 279 in models.Label.objects.filter(test__in=test_objects)) 280 else: 281 dependencies = reduce( 282 set.union, [set(test.dependencies) for test in test_objects]) 283 284 cf_info = dict(is_server=is_server, synch_count=synch_count, 285 dependencies=list(dependencies)) 286 return cf_info, test_objects, profiler_objects 287 288 289def check_job_dependencies(host_objects, job_dependencies): 290 """ 291 Check that a set of machines satisfies a job's dependencies. 
292 host_objects: list of models.Host objects 293 job_dependencies: list of names of labels 294 """ 295 # check that hosts satisfy dependencies 296 host_ids = [host.id for host in host_objects] 297 hosts_in_job = models.Host.objects.filter(id__in=host_ids) 298 ok_hosts = hosts_in_job 299 for index, dependency in enumerate(job_dependencies): 300 if not provision.is_for_special_action(dependency): 301 try: 302 label = models.Label.smart_get(dependency) 303 except models.Label.DoesNotExist: 304 logging.info( 305 'Label %r does not exist, so it cannot ' 306 'be replaced by static label.', dependency) 307 label = None 308 309 if label is not None and label.is_replaced_by_static(): 310 ok_hosts = ok_hosts.filter(static_labels__name=dependency) 311 else: 312 ok_hosts = ok_hosts.filter(labels__name=dependency) 313 314 failing_hosts = (set(host.hostname for host in host_objects) - 315 set(host.hostname for host in ok_hosts)) 316 if failing_hosts: 317 raise model_logic.ValidationError( 318 {'hosts' : 'Host(s) failed to meet job dependencies (' + 319 (', '.join(job_dependencies)) + '): ' + 320 (', '.join(failing_hosts))}) 321 322 323def check_job_metahost_dependencies(metahost_objects, job_dependencies): 324 """ 325 Check that at least one machine within the metahost spec satisfies the job's 326 dependencies. 327 328 @param metahost_objects A list of label objects representing the metahosts. 329 @param job_dependencies A list of strings of the required label names. 330 @raises NoEligibleHostException If a metahost cannot run the job. 
331 """ 332 for metahost in metahost_objects: 333 if metahost.is_replaced_by_static(): 334 static_metahost = models.StaticLabel.smart_get(metahost.name) 335 hosts = models.Host.objects.filter(static_labels=static_metahost) 336 else: 337 hosts = models.Host.objects.filter(labels=metahost) 338 339 for label_name in job_dependencies: 340 if not provision.is_for_special_action(label_name): 341 try: 342 label = models.Label.smart_get(label_name) 343 except models.Label.DoesNotExist: 344 logging.info('Label %r does not exist, so it cannot ' 345 'be replaced by static label.', label_name) 346 label = None 347 348 if label is not None and label.is_replaced_by_static(): 349 hosts = hosts.filter(static_labels__name=label_name) 350 else: 351 hosts = hosts.filter(labels__name=label_name) 352 353 if not any(hosts): 354 raise error.NoEligibleHostException("No hosts within %s satisfy %s." 355 % (metahost.name, ', '.join(job_dependencies))) 356 357 358def _execution_key_for(host_queue_entry): 359 return (host_queue_entry.job.id, host_queue_entry.execution_subdir) 360 361 362def check_abort_synchronous_jobs(host_queue_entries): 363 # ensure user isn't aborting part of a synchronous autoserv execution 364 count_per_execution = {} 365 for queue_entry in host_queue_entries: 366 key = _execution_key_for(queue_entry) 367 count_per_execution.setdefault(key, 0) 368 count_per_execution[key] += 1 369 370 for queue_entry in host_queue_entries: 371 if not queue_entry.execution_subdir: 372 continue 373 execution_count = count_per_execution[_execution_key_for(queue_entry)] 374 if execution_count < queue_entry.job.synch_count: 375 raise model_logic.ValidationError( 376 {'' : 'You cannot abort part of a synchronous job execution ' 377 '(%d/%s), %d included, %d expected' 378 % (queue_entry.job.id, queue_entry.execution_subdir, 379 execution_count, queue_entry.job.synch_count)}) 380 381 382def check_modify_host(update_data): 383 """ 384 Check modify_host* requests. 
385 386 @param update_data: A dictionary with the changes to make to a host 387 or hosts. 388 """ 389 # Only the scheduler (monitor_db) is allowed to modify Host status. 390 # Otherwise race conditions happen as a hosts state is changed out from 391 # beneath tasks being run on a host. 392 if 'status' in update_data: 393 raise model_logic.ValidationError({ 394 'status': 'Host status can not be modified by the frontend.'}) 395 396 397def check_modify_host_locking(host, update_data): 398 """ 399 Checks when locking/unlocking has been requested if the host is already 400 locked/unlocked. 401 402 @param host: models.Host object to be modified 403 @param update_data: A dictionary with the changes to make to the host. 404 """ 405 locked = update_data.get('locked', None) 406 lock_reason = update_data.get('lock_reason', None) 407 if locked is not None: 408 if locked and host.locked: 409 raise model_logic.ValidationError({ 410 'locked': 'Host %s already locked by %s on %s.' % 411 (host.hostname, host.locked_by, host.lock_time)}) 412 if not locked and not host.locked: 413 raise model_logic.ValidationError({ 414 'locked': 'Host %s already unlocked.' 
% host.hostname}) 415 if locked and not lock_reason and not host.locked: 416 raise model_logic.ValidationError({ 417 'locked': 'Please provide a reason for locking Host %s' % 418 host.hostname}) 419 420 421def get_motd(): 422 dirname = os.path.dirname(__file__) 423 filename = os.path.join(dirname, "..", "..", "motd.txt") 424 text = '' 425 try: 426 fp = open(filename, "r") 427 try: 428 text = fp.read() 429 finally: 430 fp.close() 431 except: 432 pass 433 434 return text 435 436 437def _get_metahost_counts(metahost_objects): 438 metahost_counts = {} 439 for metahost in metahost_objects: 440 metahost_counts.setdefault(metahost, 0) 441 metahost_counts[metahost] += 1 442 return metahost_counts 443 444 445def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None): 446 hosts = [] 447 one_time_hosts = [] 448 meta_hosts = [] 449 hostless = False 450 451 queue_entries = job.hostqueueentry_set.all() 452 if queue_entry_filter_data: 453 queue_entries = models.HostQueueEntry.query_objects( 454 queue_entry_filter_data, initial_query=queue_entries) 455 456 for queue_entry in queue_entries: 457 if (queue_entry.host and (preserve_metahosts or 458 not queue_entry.meta_host)): 459 if queue_entry.deleted: 460 continue 461 if queue_entry.host.invalid: 462 one_time_hosts.append(queue_entry.host) 463 else: 464 hosts.append(queue_entry.host) 465 elif queue_entry.meta_host: 466 meta_hosts.append(queue_entry.meta_host) 467 else: 468 hostless = True 469 470 meta_host_counts = _get_metahost_counts(meta_hosts) 471 472 info = dict(dependencies=[label.name for label 473 in job.dependency_labels.all()], 474 hosts=hosts, 475 meta_hosts=meta_hosts, 476 meta_host_counts=meta_host_counts, 477 one_time_hosts=one_time_hosts, 478 hostless=hostless) 479 return info 480 481 482def check_for_duplicate_hosts(host_objects): 483 host_counts = collections.Counter(host_objects) 484 duplicate_hostnames = { 485 host.hostname 486 for host, count in six.iteritems(host_counts) if count > 1 487 } 
488 if duplicate_hostnames: 489 raise model_logic.ValidationError( 490 {'hosts' : 'Duplicate hosts: %s' 491 % ', '.join(duplicate_hostnames)}) 492 493 494def create_new_job(owner, options, host_objects, metahost_objects): 495 all_host_objects = host_objects + metahost_objects 496 dependencies = options.get('dependencies', []) 497 synch_count = options.get('synch_count') 498 499 if synch_count is not None and synch_count > len(all_host_objects): 500 raise model_logic.ValidationError( 501 {'hosts': 502 'only %d hosts provided for job with synch_count = %d' % 503 (len(all_host_objects), synch_count)}) 504 505 check_for_duplicate_hosts(host_objects) 506 507 for label_name in dependencies: 508 if provision.is_for_special_action(label_name): 509 # TODO: We could save a few queries 510 # if we had a bulk ensure-label-exists function, which used 511 # a bulk .get() call. The win is probably very small. 512 _ensure_label_exists(label_name) 513 514 # This only checks targeted hosts, not hosts eligible due to the metahost 515 check_job_dependencies(host_objects, dependencies) 516 check_job_metahost_dependencies(metahost_objects, dependencies) 517 518 options['dependencies'] = list( 519 models.Label.objects.filter(name__in=dependencies)) 520 521 job = models.Job.create(owner=owner, options=options, 522 hosts=all_host_objects) 523 job.queue(all_host_objects, 524 is_template=options.get('is_template', False)) 525 return job.id 526 527 528def _ensure_label_exists(name): 529 """ 530 Ensure that a label called |name| exists in the Django models. 531 532 This function is to be called from within afe rpcs only, as an 533 alternative to server.cros.provision.ensure_label_exists(...). It works 534 by Django model manipulation, rather than by making another create_label 535 rpc call. 536 537 @param name: the label to check for/create. 538 @raises ValidationError: There was an error in the response that was 539 not because the label already existed. 
540 @returns True is a label was created, False otherwise. 541 """ 542 # Make sure this function is not called on shards but only on main. 543 assert not server_utils.is_shard() 544 try: 545 models.Label.objects.get(name=name) 546 except models.Label.DoesNotExist: 547 try: 548 new_label = models.Label.objects.create(name=name) 549 new_label.save() 550 return True 551 except django.db.utils.IntegrityError as e: 552 # It is possible that another suite/test already 553 # created the label between the check and save. 554 if DUPLICATE_KEY_MSG in str(e): 555 return False 556 else: 557 raise 558 return False 559 560 561def find_platform(hostname, label_list): 562 """ 563 Figure out the platform name for the given host 564 object. If none, the return value for either will be None. 565 566 @param hostname: The hostname to find platform. 567 @param label_list: The label list to find platform. 568 569 @returns platform name for the given host. 570 """ 571 platforms = [label.name for label in label_list if label.platform] 572 if not platforms: 573 platform = None 574 else: 575 platform = platforms[0] 576 577 if len(platforms) > 1: 578 raise ValueError('Host %s has more than one platform: %s' % 579 (hostname, ', '.join(platforms))) 580 581 return platform 582 583 584# support for get_host_queue_entries_and_special_tasks() 585 586def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on): 587 return dict(type=type, 588 host=entry['host'], 589 job=job_dict, 590 execution_path=exec_path, 591 status=status, 592 started_on=started_on, 593 id=str(entry['id']) + type, 594 oid=entry['id']) 595 596 597def _special_task_to_dict(task, queue_entries): 598 """Transforms a special task dictionary to another form of dictionary. 599 600 @param task Special task as a dictionary type 601 @param queue_entries Host queue entries as a list of dictionaries. 602 603 @return Transformed dictionary for a special task. 
604 """ 605 job_dict = None 606 if task['queue_entry']: 607 # Scan queue_entries to get the job detail info. 608 for qentry in queue_entries: 609 if task['queue_entry']['id'] == qentry['id']: 610 job_dict = qentry['job'] 611 break 612 # If not found, get it from DB. 613 if job_dict is None: 614 job = models.Job.objects.get(id=task['queue_entry']['job']) 615 job_dict = job.get_object_dict() 616 617 exec_path = server_utils.get_special_task_exec_path( 618 task['host']['hostname'], task['id'], task['task'], 619 time_utils.time_string_to_datetime(task['time_requested'])) 620 status = server_utils.get_special_task_status( 621 task['is_complete'], task['success'], task['is_active']) 622 return _common_entry_to_dict(task, task['task'], job_dict, 623 exec_path, status, task['time_started']) 624 625 626def _queue_entry_to_dict(queue_entry): 627 job_dict = queue_entry['job'] 628 tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner']) 629 exec_path = server_utils.get_hqe_exec_path(tag, 630 queue_entry['execution_subdir']) 631 return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path, 632 queue_entry['status'], queue_entry['started_on']) 633 634 635def prepare_host_queue_entries_and_special_tasks(interleaved_entries, 636 queue_entries): 637 """ 638 Prepare for serialization the interleaved entries of host queue entries 639 and special tasks. 640 Each element in the entries is a dictionary type. 641 The special task dictionary has only a job id for a job and lacks 642 the detail of the job while the host queue entry dictionary has. 643 queue_entries is used to look up the job detail info. 644 645 @param interleaved_entries Host queue entries and special tasks as a list 646 of dictionaries. 647 @param queue_entries Host queue entries as a list of dictionaries. 648 649 @return A post-processed list of dictionaries that is to be serialized. 
650 """ 651 dict_list = [] 652 for e in interleaved_entries: 653 # Distinguish the two mixed entries based on the existence of 654 # the key "task". If an entry has the key, the entry is for 655 # special task. Otherwise, host queue entry. 656 if 'task' in e: 657 dict_list.append(_special_task_to_dict(e, queue_entries)) 658 else: 659 dict_list.append(_queue_entry_to_dict(e)) 660 return prepare_for_serialization(dict_list) 661 662 663def _compute_next_job_for_tasks(queue_entries, special_tasks): 664 """ 665 For each task, try to figure out the next job that ran after that task. 666 This is done using two pieces of information: 667 * if the task has a queue entry, we can use that entry's job ID. 668 * if the task has a time_started, we can try to compare that against the 669 started_on field of queue_entries. this isn't guaranteed to work perfectly 670 since queue_entries may also have null started_on values. 671 * if the task has neither, or if use of time_started fails, just use the 672 last computed job ID. 673 674 @param queue_entries Host queue entries as a list of dictionaries. 675 @param special_tasks Special tasks as a list of dictionaries. 
676 """ 677 next_job_id = None # most recently computed next job 678 hqe_index = 0 # index for scanning by started_on times 679 for task in special_tasks: 680 if task['queue_entry']: 681 next_job_id = task['queue_entry']['job'] 682 elif task['time_started'] is not None: 683 for queue_entry in queue_entries[hqe_index:]: 684 if queue_entry['started_on'] is None: 685 continue 686 t1 = time_utils.time_string_to_datetime( 687 queue_entry['started_on']) 688 t2 = time_utils.time_string_to_datetime(task['time_started']) 689 if t1 < t2: 690 break 691 next_job_id = queue_entry['job']['id'] 692 693 task['next_job_id'] = next_job_id 694 695 # advance hqe_index to just after next_job_id 696 if next_job_id is not None: 697 for queue_entry in queue_entries[hqe_index:]: 698 if queue_entry['job']['id'] < next_job_id: 699 break 700 hqe_index += 1 701 702 703def interleave_entries(queue_entries, special_tasks): 704 """ 705 Both lists should be ordered by descending ID. 706 """ 707 _compute_next_job_for_tasks(queue_entries, special_tasks) 708 709 # start with all special tasks that've run since the last job 710 interleaved_entries = [] 711 for task in special_tasks: 712 if task['next_job_id'] is not None: 713 break 714 interleaved_entries.append(task) 715 716 # now interleave queue entries with the remaining special tasks 717 special_task_index = len(interleaved_entries) 718 for queue_entry in queue_entries: 719 interleaved_entries.append(queue_entry) 720 # add all tasks that ran between this job and the previous one 721 for task in special_tasks[special_task_index:]: 722 if task['next_job_id'] < queue_entry['job']['id']: 723 break 724 interleaved_entries.append(task) 725 special_task_index += 1 726 727 return interleaved_entries 728 729 730def bucket_hosts_by_shard(host_objs): 731 """Figure out which hosts are on which shards. 732 733 @param host_objs: A list of host objects. 734 735 @return: A map of shard hostname: list of hosts on the shard. 
736 """ 737 shard_host_map = collections.defaultdict(list) 738 for host in host_objs: 739 if host.shard: 740 shard_host_map[host.shard.hostname].append(host.hostname) 741 return shard_host_map 742 743 744def create_job_common( 745 name, 746 priority, 747 control_type, 748 control_file=None, 749 hosts=(), 750 meta_hosts=(), 751 one_time_hosts=(), 752 synch_count=None, 753 is_template=False, 754 timeout=None, 755 timeout_mins=None, 756 max_runtime_mins=None, 757 run_verify=True, 758 email_list='', 759 dependencies=(), 760 reboot_before=None, 761 reboot_after=None, 762 parse_failed_repair=None, 763 hostless=False, 764 keyvals=None, 765 drone_set=None, 766 parent_job_id=None, 767 run_reset=True, 768 require_ssp=None): 769 #pylint: disable-msg=C0111 770 """ 771 Common code between creating "standard" jobs and creating parameterized jobs 772 """ 773 # input validation 774 host_args_passed = any((hosts, meta_hosts, one_time_hosts)) 775 if hostless: 776 if host_args_passed: 777 raise model_logic.ValidationError({ 778 'hostless': 'Hostless jobs cannot include any hosts!'}) 779 if control_type != control_data.CONTROL_TYPE_NAMES.SERVER: 780 raise model_logic.ValidationError({ 781 'control_type': 'Hostless jobs cannot use client-side ' 782 'control files'}) 783 elif not host_args_passed: 784 raise model_logic.ValidationError({ 785 'arguments' : "For host jobs, you must pass at least one of" 786 " 'hosts', 'meta_hosts', 'one_time_hosts'." 
787 }) 788 label_objects = list(models.Label.objects.filter(name__in=meta_hosts)) 789 790 # convert hostnames & meta hosts to host/label objects 791 host_objects = models.Host.smart_get_bulk(hosts) 792 _validate_host_job_sharding(host_objects) 793 for host in one_time_hosts: 794 this_host = models.Host.create_one_time_host(host) 795 host_objects.append(this_host) 796 797 metahost_objects = [] 798 meta_host_labels_by_name = {label.name: label for label in label_objects} 799 for label_name in meta_hosts: 800 if label_name in meta_host_labels_by_name: 801 metahost_objects.append(meta_host_labels_by_name[label_name]) 802 else: 803 raise model_logic.ValidationError( 804 {'meta_hosts' : 'Label "%s" not found' % label_name}) 805 806 options = dict(name=name, 807 priority=priority, 808 control_file=control_file, 809 control_type=control_type, 810 is_template=is_template, 811 timeout=timeout, 812 timeout_mins=timeout_mins, 813 max_runtime_mins=max_runtime_mins, 814 synch_count=synch_count, 815 run_verify=run_verify, 816 email_list=email_list, 817 dependencies=dependencies, 818 reboot_before=reboot_before, 819 reboot_after=reboot_after, 820 parse_failed_repair=parse_failed_repair, 821 keyvals=keyvals, 822 drone_set=drone_set, 823 parent_job_id=parent_job_id, 824 # TODO(crbug.com/873716) DEPRECATED. Remove entirely. 
825 test_retry=0, 826 run_reset=run_reset, 827 require_ssp=require_ssp) 828 829 return create_new_job(owner=models.User.current_user().login, 830 options=options, 831 host_objects=host_objects, 832 metahost_objects=metahost_objects) 833 834 835def _validate_host_job_sharding(host_objects): 836 """Check that the hosts obey job sharding rules.""" 837 if not (server_utils.is_shard() 838 or _allowed_hosts_for_main_job(host_objects)): 839 shard_host_map = bucket_hosts_by_shard(host_objects) 840 raise ValueError( 841 'The following hosts are on shard(s), please create ' 842 'seperate jobs for hosts on each shard: %s ' % 843 shard_host_map) 844 845 846def _allowed_hosts_for_main_job(host_objects): 847 """Check that the hosts are allowed for a job on main.""" 848 # We disallow the following jobs on main: 849 # num_shards > 1: this is a job spanning across multiple shards. 850 # num_shards == 1 but number of hosts on shard is less 851 # than total number of hosts: this is a job that spans across 852 # one shard and the main. 853 shard_host_map = bucket_hosts_by_shard(host_objects) 854 num_shards = len(shard_host_map) 855 if num_shards > 1: 856 return False 857 if num_shards == 1: 858 hosts_on_shard = list(shard_host_map.values())[0] 859 assert len(hosts_on_shard) <= len(host_objects) 860 return len(hosts_on_shard) == len(host_objects) 861 else: 862 return True 863 864 865def encode_ascii(control_file): 866 """Force a control file to only contain ascii characters. 867 868 @param control_file: Control file to encode. 869 870 @returns the control file in an ascii encoding. 871 872 @raises error.ControlFileMalformed: if encoding fails. 873 """ 874 try: 875 return control_file.encode('ascii') 876 except UnicodeDecodeError as e: 877 raise error.ControlFileMalformed(str(e)) 878 879 880def get_wmatrix_url(): 881 """Get wmatrix url from config file. 882 883 @returns the wmatrix url or an empty string. 
884 """ 885 return global_config.global_config.get_config_value('AUTOTEST_WEB', 886 'wmatrix_url', 887 default='') 888 889 890def get_stainless_url(): 891 """Get stainless url from config file. 892 893 @returns the stainless url or an empty string. 894 """ 895 return global_config.global_config.get_config_value('AUTOTEST_WEB', 896 'stainless_url', 897 default='') 898 899 900def inject_times_to_filter(start_time_key=None, end_time_key=None, 901 start_time_value=None, end_time_value=None, 902 **filter_data): 903 """Inject the key value pairs of start and end time if provided. 904 905 @param start_time_key: A string represents the filter key of start_time. 906 @param end_time_key: A string represents the filter key of end_time. 907 @param start_time_value: Start_time value. 908 @param end_time_value: End_time value. 909 910 @returns the injected filter_data. 911 """ 912 if start_time_value: 913 filter_data[start_time_key] = start_time_value 914 if end_time_value: 915 filter_data[end_time_key] = end_time_value 916 return filter_data 917 918 919def inject_times_to_hqe_special_tasks_filters(filter_data_common, 920 start_time, end_time): 921 """Inject start and end time to hqe and special tasks filters. 922 923 @param filter_data_common: Common filter for hqe and special tasks. 924 @param start_time_key: A string represents the filter key of start_time. 925 @param end_time_key: A string represents the filter key of end_time. 926 927 @returns a pair of hqe and special tasks filters. 928 """ 929 filter_data_special_tasks = filter_data_common.copy() 930 return (inject_times_to_filter('started_on__gte', 'started_on__lte', 931 start_time, end_time, **filter_data_common), 932 inject_times_to_filter('time_started__gte', 'time_started__lte', 933 start_time, end_time, 934 **filter_data_special_tasks)) 935 936 937def retrieve_shard(shard_hostname): 938 """ 939 Retrieves the shard with the given hostname from the database. 

    @param shard_hostname: Hostname of the shard to retrieve

    @raises models.Shard.DoesNotExist, if no shard with this hostname was found.

    @returns: Shard object
    """
    return models.Shard.smart_get(shard_hostname)


def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Find records that should be sent to a shard.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple of lists:
        (hosts, jobs, suite_job_keyvals, invalid_host_ids)
    """
    hosts, invalid_host_ids = models.Host.assign_to_shard(
            shard, known_host_ids)
    jobs = models.Job.assign_to_shard(shard, known_job_ids)
    # Suite job keyvals are looked up via the parent (suite) job ids of the
    # jobs being assigned to the shard.
    parent_job_ids = [job.parent_job_id for job in jobs]
    suite_job_keyvals = models.JobKeyval.objects.filter(
            job_id__in=parent_job_ids)
    return hosts, jobs, suite_job_keyvals, invalid_host_ids


def _persist_records_with_type_sent_from_shard(
        shard, records, record_type, *args, **kwargs):
    """
    Handle records of a specified type that were sent to the shard main.

    For each serialized record: look up the existing main-side row, run the
    model's shard-update check, and only persist records whose check passed.

    @param shard: The shard the records were sent from.
    @param records: The records sent in their serialized format.
    @param record_type: Type of the objects represented by records.
    @param args: Additional arguments that will be passed on to the checks.
    @param kwargs: Additional arguments that will be passed on to the checks.

    @raises error.UnallowedRecordsSentToMain if any of the checks fail.

    @returns: List of primary keys of the processed records.
    """
    pks = []
    for serialized_record in records:
        pk = serialized_record['id']
        try:
            # Shards may only update records that already exist on the main.
            current_record = record_type.objects.get(pk=pk)
        except record_type.DoesNotExist:
            raise error.UnallowedRecordsSentToMain(
                    'Object with pk %s of type %s does not exist on main.' % (
                            pk, record_type))

        try:
            current_record._check_update_from_shard(
                shard, serialized_record, *args, **kwargs)
        except error.IgnorableUnallowedRecordsSentToMain:
            # An illegal record change was attempted, but it was of a non-fatal
            # variety. Silently skip this record.
            pass
        else:
            current_record.update_from_serialized(serialized_record)
            pks.append(pk)

    return pks


def persist_records_sent_from_shard(shard, jobs, hqes):
    """
    Checking then saving serialized records sent to main from shard.

    During heartbeats shards upload jobs and hostqueueentries. This performs
    some checks on these and then updates the existing records for those
    entries with the updated ones from the heartbeat.

    The checks include:
    - Checking if the objects sent already exist on the main.
    - Checking if the objects sent were assigned to this shard.
    - hostqueueentries must be sent together with their jobs.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The hostqueueentries the shard sent.

    @raises error.UnallowedRecordsSentToMain if any of the checks fail.
    """
    # Jobs are persisted first so the hqe check can be given the set of job
    # ids that arrived in the same heartbeat (job_ids_sent).
    job_ids_persisted = _persist_records_with_type_sent_from_shard(
            shard, jobs, models.Job)
    _persist_records_with_type_sent_from_shard(
            shard, hqes, models.HostQueueEntry,
            job_ids_sent=job_ids_persisted)


def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change the host attributes
    should be forwarded to the shard.

    This assumes the first argument of the function represents a host id.

    @param func: The function to decorate

    @returns: The function to replace func with.
1045 """ 1046 def replacement(**kwargs): 1047 # Only keyword arguments can be accepted here, as we need the argument 1048 # names to send the rpc. serviceHandler always provides arguments with 1049 # their keywords, so this is not a problem. 1050 1051 # A host record (identified by kwargs['id']) can be deleted in 1052 # func(). Therefore, we should save the data that can be needed later 1053 # before func() is called. 1054 shard_hostname = None 1055 host = models.Host.smart_get(kwargs['id']) 1056 if host and host.shard: 1057 shard_hostname = host.shard.hostname 1058 ret = func(**kwargs) 1059 if shard_hostname and not server_utils.is_shard(): 1060 run_rpc_on_multiple_hostnames(func.__name__, [shard_hostname], 1061 **kwargs) 1062 return ret 1063 1064 return replacement 1065 1066 1067def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs): 1068 """Fanout the given rpc to shards of given hosts. 1069 1070 @param host_objs: Host objects for the rpc. 1071 @param rpc_name: The name of the rpc. 1072 @param include_hostnames: If True, include the hostnames in the kwargs. 1073 Hostnames are not always necessary, this functions is designed to 1074 send rpcs to the shard a host is on, the rpcs themselves could be 1075 related to labels, acls etc. 1076 @param kwargs: The kwargs for the rpc. 1077 """ 1078 # Figure out which hosts are on which shards. 1079 shard_host_map = bucket_hosts_by_shard(host_objs) 1080 1081 # Execute the rpc against the appropriate shards. 1082 for shard, hostnames in six.iteritems(shard_host_map): 1083 if include_hostnames: 1084 kwargs['hosts'] = hostnames 1085 try: 1086 run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs) 1087 except Exception as e: 1088 raise error.RPCException('RPC %s failed on shard %s due to %s' % 1089 (rpc_name, shard, e)) 1090 1091 1092def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs): 1093 """Runs an rpc to multiple AFEs 1094 1095 This is i.e. 
used to propagate changes made to hosts after they are assigned 1096 to a shard. 1097 1098 @param rpc_call: Name of the rpc endpoint to call. 1099 @param shard_hostnames: List of hostnames to run the rpcs on. 1100 @param **kwargs: Keyword arguments to pass in the rpcs. 1101 """ 1102 # Make sure this function is not called on shards but only on main. 1103 assert not server_utils.is_shard() 1104 for shard_hostname in shard_hostnames: 1105 afe = frontend_wrappers.RetryingAFE(server=shard_hostname, 1106 user=thread_local.get_user()) 1107 afe.run(rpc_call, **kwargs) 1108 1109 1110def get_label(name): 1111 """Gets a label object using a given name. 1112 1113 @param name: Label name. 1114 @raises model.Label.DoesNotExist: when there is no label matching 1115 the given name. 1116 @return: a label object matching the given name. 1117 """ 1118 try: 1119 label = models.Label.smart_get(name) 1120 except models.Label.DoesNotExist: 1121 return None 1122 return label 1123 1124 1125# TODO: hide the following rpcs under is_moblab 1126def moblab_only(func): 1127 """Ensure moblab specific functions only run on Moblab devices.""" 1128 def verify(*args, **kwargs): 1129 if not server_utils.is_moblab(): 1130 raise error.RPCException('RPC: %s can only run on Moblab Systems!', 1131 func.__name__) 1132 return func(*args, **kwargs) 1133 return verify 1134 1135 1136def route_rpc_to_main(func): 1137 """Route RPC to main AFE. 1138 1139 When a shard receives an RPC decorated by this, the RPC is just 1140 forwarded to the main. 1141 When the main gets the RPC, the RPC function is executed. 1142 1143 @param func: An RPC function to decorate 1144 1145 @returns: A function replacing the RPC func. 
1146 """ 1147 argspec = inspect.getargspec(func) 1148 if argspec.varargs is not None: 1149 raise Exception('RPC function must not have *args.') 1150 1151 @wraps(func) 1152 def replacement(*args, **kwargs): 1153 """We need special handling when decorating an RPC that can be called 1154 directly using positional arguments. 1155 1156 One example is rpc_interface.create_job(). 1157 rpc_interface.create_job_page_handler() calls the function using both 1158 positional and keyword arguments. Since frontend.RpcClient.run() 1159 takes only keyword arguments for an RPC, positional arguments of the 1160 RPC function need to be transformed into keyword arguments. 1161 """ 1162 kwargs = _convert_to_kwargs_only(func, args, kwargs) 1163 if server_utils.is_shard(): 1164 afe = frontend_wrappers.RetryingAFE( 1165 server=server_utils.get_global_afe_hostname(), 1166 user=thread_local.get_user()) 1167 return afe.run(func.__name__, **kwargs) 1168 return func(**kwargs) 1169 1170 return replacement 1171 1172 1173def _convert_to_kwargs_only(func, args, kwargs): 1174 """Convert a function call's arguments to a kwargs dict. 1175 1176 This is best illustrated with an example. Given: 1177 1178 def foo(a, b, **kwargs): 1179 pass 1180 _to_kwargs(foo, (1, 2), {'c': 3}) # corresponding to foo(1, 2, c=3) 1181 1182 foo(**kwargs) 1183 1184 @param func: function whose signature to use 1185 @param args: positional arguments of call 1186 @param kwargs: keyword arguments of call 1187 1188 @returns: kwargs dict 1189 """ 1190 argspec = inspect.getargspec(func) 1191 # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}} 1192 callargs = inspect.getcallargs(func, *args, **kwargs) 1193 if argspec.keywords is None: 1194 kwargs = {} 1195 else: 1196 kwargs = callargs.pop(argspec.keywords) 1197 kwargs.update(callargs) 1198 return kwargs 1199 1200 1201def get_sample_dut(board, pool): 1202 """Get a dut with the given board and pool. 
1203 1204 This method is used to help to locate a dut with the given board and pool. 1205 The dut then can be used to identify a devserver in the same subnet. 1206 1207 @param board: Name of the board. 1208 @param pool: Name of the pool. 1209 1210 @return: Name of a dut with the given board and pool. 1211 """ 1212 if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board): 1213 return None 1214 1215 hosts = list(get_host_query( 1216 multiple_labels=('pool:%s' % pool, 'board:%s' % board), 1217 exclude_only_if_needed_labels=False, 1218 valid_only=True, 1219 filter_data={}, 1220 )) 1221 if not hosts: 1222 return None 1223 else: 1224 return hosts[0].hostname 1225