• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# pylint: disable=missing-docstring
2"""
3Utility functions for rpc_interface.py.  We keep them in a separate file so that
4only RPC interface functions go into that file.
5"""
6
7__author__ = 'showard@google.com (Steve Howard)'
8
9import collections
10import datetime
11import inspect
12import logging
13import os
14from functools import wraps
15
16import django.db.utils
17import django.http
18import six
19from autotest_lib.client.common_lib import (control_data, error, global_config,
20                                            time_utils)
21from autotest_lib.client.common_lib.cros import dev_server
22from autotest_lib.frontend import thread_local
23from autotest_lib.frontend.afe import model_logic, models
24from autotest_lib.server import utils as server_utils
25from autotest_lib.server.cros import provision
26from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
27
28NULL_DATETIME = datetime.datetime.max
29NULL_DATE = datetime.date.max
30DUPLICATE_KEY_MSG = 'Duplicate entry'
31RESPECT_STATIC_LABELS = global_config.global_config.get_config_value(
32        'SKYLAB', 'respect_static_labels', type=bool, default=False)
33
def prepare_for_serialization(objects):
    """
    Prepare Python objects to be returned via RPC.
    @param objects: objects to be prepared.
    """
    # A non-empty list of id-keyed dicts is deduplicated before conversion.
    looks_like_object_dicts = (isinstance(objects, list)
                               and bool(objects)
                               and isinstance(objects[0], dict)
                               and 'id' in objects[0])
    if looks_like_object_dicts:
        objects = _gather_unique_dicts(objects)
    return _prepare_data(objects)
43
44
def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
    """
    Prepare a Django query to be returned via RPC as a sequence of nested
    dictionaries.

    @param query - A Django model query object with a select_related() method.
    @param nested_dict_column_names - A list of column/attribute names for the
            rows returned by query to expand into nested dictionaries using
            their get_object_dict() method when not None.

    @returns A list suitable to be returned in an RPC.
    """
    def expand(row):
        # Turn one row into a dict, expanding each requested column into the
        # related object's own dict when the column is set.
        row_dict = row.get_object_dict()
        for column in nested_dict_column_names:
            if row_dict[column] is not None:
                row_dict[column] = getattr(row, column).get_object_dict()
        return row_dict

    return prepare_for_serialization(
            [expand(row) for row in query.select_related()])
65
66
67def _prepare_data(data):
68    """
69    Recursively process data structures, performing necessary type
70    conversions to values in data to allow for RPC serialization:
71    -convert datetimes to strings
72    -convert tuples and sets to lists
73    """
74    if isinstance(data, dict):
75        new_data = {}
76        for key, value in six.iteritems(data):
77            new_data[key] = _prepare_data(value)
78        return new_data
79    elif (isinstance(data, list) or isinstance(data, tuple) or
80          isinstance(data, set)):
81        return [_prepare_data(item) for item in data]
82    elif isinstance(data, datetime.date):
83        if data is NULL_DATETIME or data is NULL_DATE:
84            return None
85        return str(data)
86    else:
87        return data
88
89
def fetchall_as_list_of_dicts(cursor):
    """
    Converts each row in the cursor to a dictionary so that values can be read
    by using the column name.
    @param cursor: The database cursor to read from.
    @returns: A list of each row in the cursor as a dictionary.
    """
    column_names = [column[0] for column in cursor.description]
    return [dict(zip(column_names, row)) for row in cursor.fetchall()]
100
101
def raw_http_response(response_data, content_type=None):
    """
    Wrap raw data in an HttpResponse with an explicit Content-length header.

    @param response_data: raw body for the response.
    @param content_type: optional MIME type for the Content-Type header.
    @returns a django.http.HttpResponse.
    """
    # Django removed the 'mimetype' keyword in 1.7; 'content_type' is the
    # supported spelling and behaves identically here.
    response = django.http.HttpResponse(response_data,
                                        content_type=content_type)
    response['Content-length'] = str(len(response.content))
    return response
106
107
108def _gather_unique_dicts(dict_iterable):
109    """\
110    Pick out unique objects (by ID) from an iterable of object dicts.
111    """
112    objects = collections.OrderedDict()
113    for obj in dict_iterable:
114        objects.setdefault(obj['id'], obj)
115    return list(objects.values())
116
117
def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    """\
    Generate a SQL WHERE clause for job status filtering, and return it in
    a dict of keyword args to pass to query.extra().
    * not_yet_run: all HQEs are Queued
    * finished: all HQEs are complete
    * running: everything else
    """
    if not any((not_yet_run, running, finished)):
        return {}
    # Subquery of jobs that have at least one HQE past the Queued state.
    not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
                  'WHERE status != "%s")'
                  % models.HostQueueEntry.Status.QUEUED)
    # Subquery of jobs that have at least one incomplete HQE.
    not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
                    'WHERE not complete)')

    clauses = []
    if not_yet_run:
        clauses.append('id NOT IN ' + not_queued)
    if running:
        clauses.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished))
    if finished:
        clauses.append('id NOT IN ' + not_finished)
    combined = ' OR '.join('(%s)' % clause for clause in clauses)
    return {'where': [combined]}
142
143
def extra_job_type_filters(extra_args, suite=False,
                           sub=False, standalone=False):
    """\
    Generate a SQL WHERE clause for job type filtering, and return it in
    a dict of keyword args to pass to query.extra().

    param extra_args: a dict of existing extra_args.

    No more than one of the parameters should be passed as True:
    * suite: job which is parent of other jobs
    * sub: job with a parent job
    * standalone: job with no child or parent jobs
    """
    assert sum((suite, sub, standalone)) <= 1, (
            'Cannot specify more than one filter to this function')

    # Template subquery over jobs that have a parent.
    filter_common = ('(SELECT %s FROM afe_jobs '
                     'WHERE parent_job_id IS NOT NULL)')

    if suite:
        clause = 'id IN ' + filter_common % 'DISTINCT parent_job_id'
    elif sub:
        clause = 'id IN ' + filter_common % 'id'
    elif standalone:
        clause = ('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
                  'WHERE parent_job_id IS NOT NULL'
                  ' AND (sub_query.parent_job_id=afe_jobs.id'
                  ' OR sub_query.id=afe_jobs.id))')
    else:
        # No type filter requested; pass extra_args through untouched.
        return extra_args

    where = extra_args.get('where', [])
    where.append(clause)
    extra_args['where'] = where
    return extra_args
182
183
def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   valid_only, filter_data):
    """
    Build a Host queryset constrained by labels and extra filter data.

    @param multiple_labels: label names every returned host must carry.
    @param exclude_only_if_needed_labels: Deprecated. By default it's false.
    @param valid_only: when True, start from non-invalidated hosts only.
    @param filter_data: extra filter arguments for Host.query_objects.
    """
    if valid_only:
        base_query = models.Host.valid_objects.all()
    else:
        base_query = models.Host.objects.all()

    try:
        hosts = models.Host.get_hosts_with_labels(multiple_labels, base_query)
        if not hosts:
            return hosts
        return models.Host.query_objects(filter_data, initial_query=hosts)
    except models.Label.DoesNotExist:
        # An unknown label can never match any host.
        return models.Host.objects.none()
203
204
class InconsistencyException(Exception):
    """Raised when a list of objects does not have a consistent value."""
207
208
def get_consistent_value(objects, field):
    """
    Return the value of `field` shared by every object in `objects`.

    @param objects: sequence of objects to inspect.
    @param field: attribute name to compare across objects.
    @returns the common value, or None for an empty sequence.
    @raises InconsistencyException if two objects disagree on the field.
    """
    if not objects:
        # A list of nothing is trivially consistent.
        return None

    expected = getattr(objects[0], field)
    for candidate in objects:
        if getattr(candidate, field) != expected:
            raise InconsistencyException(objects[0], candidate)
    return expected
220
221
def afe_test_dict_to_test_object(test_dict):
    """
    Convert a test dictionary into an anonymous object with attributes.

    Values that parse as integers are converted to int; all other values are
    kept verbatim.

    @param test_dict: dict describing a test. Non-dict inputs are returned
            unchanged.
    @returns a new class whose attributes mirror the dict's keys.
    """
    if not isinstance(test_dict, dict):
        return test_dict

    numerized_dict = {}
    # .items() works on both py2 and py3; six.iteritems is unnecessary.
    for key, value in test_dict.items():
        try:
            numerized_dict[key] = int(value)
        except (ValueError, TypeError):
            numerized_dict[key] = value

    return type('TestObject', (object,), numerized_dict)
234
235
236def _check_is_server_test(test_type):
237    """Checks if the test type is a server test.
238
239    @param test_type The test type in enum integer or string.
240
241    @returns A boolean to identify if the test type is server test.
242    """
243    if test_type is not None:
244        if isinstance(test_type, six.string_types):
245            try:
246                test_type = control_data.CONTROL_TYPE.get_value(test_type)
247            except AttributeError:
248                return False
249        return (test_type == control_data.CONTROL_TYPE.SERVER)
250    return False
251
252
def prepare_generate_control_file(tests, profilers, db_tests=True):
    """
    Resolve tests and profilers and compute shared control-file metadata.

    @param tests: test names/ids (db_tests=True) or test dicts.
    @param profilers: profiler names/ids to resolve.
    @param db_tests: when True look tests up in the database; otherwise
            convert the given dicts via afe_test_dict_to_test_object.
    @returns a (cf_info, test_objects, profiler_objects) tuple, where cf_info
            holds is_server, synch_count and the merged dependency list.
    @raises model_logic.ValidationError if the tests mix test types.
    """
    if db_tests:
        test_objects = [models.Test.smart_get(test) for test in tests]
    else:
        test_objects = [afe_test_dict_to_test_object(test) for test in tests]

    profiler_objects = [models.Profiler.smart_get(profiler)
                        for profiler in profilers]
    # Ensure tests are all the same type.
    try:
        test_type = get_consistent_value(test_objects, 'test_type')
    except InconsistencyException as exc:
        test1, test2 = exc.args
        raise model_logic.ValidationError(
            {'tests' : 'You cannot run both test_suites and server-side '
             'tests together (tests %s and %s differ' % (
            test1.name, test2.name)})

    is_server = _check_is_server_test(test_type)
    if test_objects:
        synch_count = max(test.sync_count for test in test_objects)
    else:
        synch_count = 1

    if db_tests:
        dependencies = set(label.name for label
                           in models.Label.objects.filter(test__in=test_objects))
    else:
        # Python 3 removed the builtin reduce() (it now lives in functools,
        # which this module does not import), so accumulate the union
        # explicitly.  This also tolerates an empty test list, where
        # reduce(set.union, []) would raise.
        dependencies = set()
        for test in test_objects:
            dependencies.update(test.dependencies)

    cf_info = dict(is_server=is_server, synch_count=synch_count,
                   dependencies=list(dependencies))
    return cf_info, test_objects, profiler_objects
287
288
def check_job_dependencies(host_objects, job_dependencies):
    """
    Check that a set of machines satisfies a job's dependencies.

    @param host_objects: list of models.Host objects.
    @param job_dependencies: list of names of labels.
    @raises model_logic.ValidationError naming every host that is missing at
            least one dependency label.
    """
    # Progressively narrow the host set by each dependency label.
    host_ids = [host.id for host in host_objects]
    ok_hosts = models.Host.objects.filter(id__in=host_ids)
    for dependency in job_dependencies:
        if provision.is_for_special_action(dependency):
            # Special-action labels are handled by provisioning, not by host
            # label matching.
            continue
        try:
            label = models.Label.smart_get(dependency)
        except models.Label.DoesNotExist:
            logging.info('Label %r does not exist, so it cannot '
                         'be replaced by static label.', dependency)
            label = None

        if label is not None and label.is_replaced_by_static():
            ok_hosts = ok_hosts.filter(static_labels__name=dependency)
        else:
            ok_hosts = ok_hosts.filter(labels__name=dependency)

    failing_hosts = (set(host.hostname for host in host_objects) -
                     set(host.hostname for host in ok_hosts))
    if failing_hosts:
        raise model_logic.ValidationError(
            {'hosts' : 'Host(s) failed to meet job dependencies (' +
                       (', '.join(job_dependencies)) + '): ' +
                       (', '.join(failing_hosts))})
321
322
def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    """
    Check that at least one machine within the metahost spec satisfies the job's
    dependencies.

    @param metahost_objects A list of label objects representing the metahosts.
    @param job_dependencies A list of strings of the required label names.
    @raises NoEligibleHostException If a metahost cannot run the job.
    """
    for metahost in metahost_objects:
        # Start from every host carrying the metahost label (static or
        # regular, depending on whether the label was migrated).
        if metahost.is_replaced_by_static():
            static_metahost = models.StaticLabel.smart_get(metahost.name)
            hosts = models.Host.objects.filter(static_labels=static_metahost)
        else:
            hosts = models.Host.objects.filter(labels=metahost)

        # Narrow down to hosts that also carry every dependency label.
        for label_name in job_dependencies:
            if not provision.is_for_special_action(label_name):
                try:
                    label = models.Label.smart_get(label_name)
                except models.Label.DoesNotExist:
                    logging.info('Label %r does not exist, so it cannot '
                                 'be replaced by static label.', label_name)
                    label = None

                if label is not None and label.is_replaced_by_static():
                    hosts = hosts.filter(static_labels__name=label_name)
                else:
                    hosts = hosts.filter(labels__name=label_name)

        # any() lazily iterates the queryset; an empty result means no host
        # under this metahost can satisfy all dependencies.
        if not any(hosts):
            raise error.NoEligibleHostException("No hosts within %s satisfy %s."
                    % (metahost.name, ', '.join(job_dependencies)))
356
357
358def _execution_key_for(host_queue_entry):
359    return (host_queue_entry.job.id, host_queue_entry.execution_subdir)
360
361
def check_abort_synchronous_jobs(host_queue_entries):
    """
    Ensure the user isn't aborting part of a synchronous autoserv execution.

    @param host_queue_entries: queue entries the user wants to abort.
    @raises model_logic.ValidationError if only a subset of one synchronous
            execution is included.
    """
    # Count how many of the to-be-aborted entries belong to each execution.
    count_per_execution = collections.Counter(
            _execution_key_for(entry) for entry in host_queue_entries)

    for queue_entry in host_queue_entries:
        if not queue_entry.execution_subdir:
            # Entry never started executing; safe to abort on its own.
            continue
        execution_count = count_per_execution[_execution_key_for(queue_entry)]
        if execution_count < queue_entry.job.synch_count:
            raise model_logic.ValidationError(
                {'' : 'You cannot abort part of a synchronous job execution '
                      '(%d/%s), %d included, %d expected'
                      % (queue_entry.job.id, queue_entry.execution_subdir,
                         execution_count, queue_entry.job.synch_count)})
380
381
def check_modify_host(update_data):
    """
    Check modify_host* requests.

    @param update_data: A dictionary with the changes to make to a host
            or hosts.
    @raises model_logic.ValidationError if the request tries to set status.
    """
    # Only the scheduler (monitor_db) is allowed to modify Host status;
    # otherwise a host's state races with tasks running on it.
    if 'status' not in update_data:
        return
    raise model_logic.ValidationError({
            'status': 'Host status can not be modified by the frontend.'})
395
396
def check_modify_host_locking(host, update_data):
    """
    Checks when locking/unlocking has been requested if the host is already
    locked/unlocked.

    @param host: models.Host object to be modified
    @param update_data: A dictionary with the changes to make to the host.
    @raises model_logic.ValidationError on redundant (un)locking or a lock
            request without a reason.
    """
    locked = update_data.get('locked', None)
    lock_reason = update_data.get('lock_reason', None)
    if locked is None:
        # No lock-state change requested; nothing to validate.
        return
    if locked:
        if host.locked:
            raise model_logic.ValidationError({
                    'locked': 'Host %s already locked by %s on %s.' %
                    (host.hostname, host.locked_by, host.lock_time)})
        if not lock_reason:
            raise model_logic.ValidationError({
                    'locked': 'Please provide a reason for locking Host %s' %
                    host.hostname})
    elif not host.locked:
        raise model_logic.ValidationError({
                'locked': 'Host %s already unlocked.' % host.hostname})
419
420
def get_motd():
    """
    Return the contents of the message-of-the-day file, or '' if it cannot
    be read.
    """
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "..", "..", "motd.txt")
    try:
        # 'with' guarantees the file handle is closed even if read() fails.
        with open(filename, "r") as fp:
            return fp.read()
    except OSError:
        # A missing or unreadable motd is not an error; the banner is
        # optional.
        return ''
435
436
437def _get_metahost_counts(metahost_objects):
438    metahost_counts = {}
439    for metahost in metahost_objects:
440        metahost_counts.setdefault(metahost, 0)
441        metahost_counts[metahost] += 1
442    return metahost_counts
443
444
def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    """
    Summarize a job's queue entries into host/metahost buckets.

    @param job: models.Job instance to inspect.
    @param queue_entry_filter_data: optional filter dict applied to the job's
            queue entries before bucketing.
    @param preserve_metahosts: when True, entries that have both a host and a
            metahost are still counted as real hosts.
    @returns dict with keys dependencies, hosts, meta_hosts,
            meta_host_counts, one_time_hosts and hostless.
    """
    hosts = []
    one_time_hosts = []
    meta_hosts = []
    hostless = False

    queue_entries = job.hostqueueentry_set.all()
    if queue_entry_filter_data:
        queue_entries = models.HostQueueEntry.query_objects(
            queue_entry_filter_data, initial_query=queue_entries)

    for queue_entry in queue_entries:
        if (queue_entry.host and (preserve_metahosts or
                                  not queue_entry.meta_host)):
            if queue_entry.deleted:
                continue
            if queue_entry.host.invalid:
                # Hosts flagged invalid are reported in the one_time_hosts
                # bucket.
                one_time_hosts.append(queue_entry.host)
            else:
                hosts.append(queue_entry.host)
        elif queue_entry.meta_host:
            meta_hosts.append(queue_entry.meta_host)
        else:
            # Neither host nor metahost: this is a hostless job.
            hostless = True

    meta_host_counts = _get_metahost_counts(meta_hosts)

    info = dict(dependencies=[label.name for label
                              in job.dependency_labels.all()],
                hosts=hosts,
                meta_hosts=meta_hosts,
                meta_host_counts=meta_host_counts,
                one_time_hosts=one_time_hosts,
                hostless=hostless)
    return info
480
481
def check_for_duplicate_hosts(host_objects):
    """
    Raise if any host appears more than once in `host_objects`.

    @param host_objects: iterable of (hashable) host objects.
    @raises model_logic.ValidationError listing the duplicated hostnames.
    """
    host_counts = collections.Counter(host_objects)
    # .items() works on both py2 and py3; six.iteritems is unnecessary.
    duplicate_hostnames = {
            host.hostname
            for host, count in host_counts.items() if count > 1
    }
    if duplicate_hostnames:
        raise model_logic.ValidationError(
                {'hosts' : 'Duplicate hosts: %s'
                 % ', '.join(duplicate_hostnames)})
492
493
def create_new_job(owner, options, host_objects, metahost_objects):
    """
    Validate job parameters, then create and enqueue a new job.

    @param owner: login of the job owner.
    @param options: dict of job options. 'dependencies' (label names) and
            'synch_count' are read here; 'dependencies' is rewritten to
            Label objects before being passed to models.Job.create.
    @param host_objects: concrete Host objects for the job.
    @param metahost_objects: metahost Label objects for the job.
    @returns the id of the created job.
    @raises model_logic.ValidationError if validation fails.
    """
    all_host_objects = host_objects + metahost_objects
    dependencies = options.get('dependencies', [])
    synch_count = options.get('synch_count')

    # A synchronous job cannot require more hosts than were supplied.
    if synch_count is not None and synch_count > len(all_host_objects):
        raise model_logic.ValidationError(
                {'hosts':
                 'only %d hosts provided for job with synch_count = %d' %
                 (len(all_host_objects), synch_count)})

    check_for_duplicate_hosts(host_objects)

    for label_name in dependencies:
        if provision.is_for_special_action(label_name):
            # TODO: We could save a few queries
            # if we had a bulk ensure-label-exists function, which used
            # a bulk .get() call. The win is probably very small.
            _ensure_label_exists(label_name)

    # This only checks targeted hosts, not hosts eligible due to the metahost
    check_job_dependencies(host_objects, dependencies)
    check_job_metahost_dependencies(metahost_objects, dependencies)

    # Replace dependency names with the actual Label objects.
    options['dependencies'] = list(
            models.Label.objects.filter(name__in=dependencies))

    job = models.Job.create(owner=owner, options=options,
                            hosts=all_host_objects)
    job.queue(all_host_objects,
              is_template=options.get('is_template', False))
    return job.id
526
527
def _ensure_label_exists(name):
    """
    Ensure that a label called |name| exists in the Django models.

    This function is to be called from within afe rpcs only, as an
    alternative to server.cros.provision.ensure_label_exists(...). It works
    by Django model manipulation, rather than by making another create_label
    rpc call.

    @param name: the label to check for/create.
    @raises ValidationError: There was an error in the response that was
                             not because the label already existed.
    @returns True is a label was created, False otherwise.
    """
    # Make sure this function is not called on shards but only on main.
    assert not server_utils.is_shard()
    try:
        models.Label.objects.get(name=name)
    except models.Label.DoesNotExist:
        try:
            # objects.create() both instantiates and saves the new row; no
            # separate save() call is needed.
            models.Label.objects.create(name=name)
            return True
        except django.db.utils.IntegrityError as e:
            # It is possible that another suite/test already
            # created the label between the check and save.
            if DUPLICATE_KEY_MSG in str(e):
                return False
            raise
    return False
559
560
def find_platform(hostname, label_list):
    """
    Figure out the platform name for the given host
    object.  If none, the return value for either will be None.

    @param hostname: The hostname to find platform.
    @param label_list: The label list to find platform.

    @returns platform name for the given host.
    @raises ValueError if the host carries more than one platform label.
    """
    platforms = [label.name for label in label_list if label.platform]
    if len(platforms) > 1:
        raise ValueError('Host %s has more than one platform: %s' %
                         (hostname, ', '.join(platforms)))
    return platforms[0] if platforms else None
582
583
584# support for get_host_queue_entries_and_special_tasks()
585
586def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
587    return dict(type=type,
588                host=entry['host'],
589                job=job_dict,
590                execution_path=exec_path,
591                status=status,
592                started_on=started_on,
593                id=str(entry['id']) + type,
594                oid=entry['id'])
595
596
def _special_task_to_dict(task, queue_entries):
    """Transforms a special task dictionary to another form of dictionary.

    @param task           Special task as a dictionary type
    @param queue_entries  Host queue entries as a list of dictionaries.

    @return Transformed dictionary for a special task.
    """
    job_dict = None
    if task['queue_entry']:
        # Scan queue_entries to get the job detail info.
        for qentry in queue_entries:
            if task['queue_entry']['id'] == qentry['id']:
                job_dict = qentry['job']
                break
        # If not found, get it from DB.
        if job_dict is None:
            job = models.Job.objects.get(id=task['queue_entry']['job'])
            job_dict = job.get_object_dict()

    # Derive the on-disk results path and display status from the raw task
    # fields via the server-side helpers.
    exec_path = server_utils.get_special_task_exec_path(
            task['host']['hostname'], task['id'], task['task'],
            time_utils.time_string_to_datetime(task['time_requested']))
    status = server_utils.get_special_task_status(
            task['is_complete'], task['success'], task['is_active'])
    return _common_entry_to_dict(task, task['task'], job_dict,
            exec_path, status, task['time_started'])
624
625
def _queue_entry_to_dict(queue_entry):
    """Transform a host queue entry dict into the common display dict."""
    job_dict = queue_entry['job']
    tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner'])
    exec_path = server_utils.get_hqe_exec_path(
            tag, queue_entry['execution_subdir'])
    return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path,
                                 queue_entry['status'],
                                 queue_entry['started_on'])
633
634
def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
                                                 queue_entries):
    """
    Prepare for serialization the interleaved entries of host queue entries
    and special tasks.
    Each element in the entries is a dictionary type.
    The special task dictionary has only a job id for a job and lacks
    the detail of the job while the host queue entry dictionary has.
    queue_entries is used to look up the job detail info.

    @param interleaved_entries  Host queue entries and special tasks as a list
                                of dictionaries.
    @param queue_entries        Host queue entries as a list of dictionaries.

    @return A post-processed list of dictionaries that is to be serialized.
    """
    # Entries carrying a 'task' key are special tasks; everything else is a
    # host queue entry.
    dict_list = [
            _special_task_to_dict(entry, queue_entries) if 'task' in entry
            else _queue_entry_to_dict(entry)
            for entry in interleaved_entries
    ]
    return prepare_for_serialization(dict_list)
661
662
def _compute_next_job_for_tasks(queue_entries, special_tasks):
    """
    For each task, try to figure out the next job that ran after that task.
    This is done using two pieces of information:
    * if the task has a queue entry, we can use that entry's job ID.
    * if the task has a time_started, we can try to compare that against the
      started_on field of queue_entries. this isn't guaranteed to work perfectly
      since queue_entries may also have null started_on values.
    * if the task has neither, or if use of time_started fails, just use the
      last computed job ID.

    Each task dict gains a 'next_job_id' key as a side effect.

    @param queue_entries    Host queue entries as a list of dictionaries,
                            ordered by descending ID (see interleave_entries).
    @param special_tasks    Special tasks as a list of dictionaries.
    """
    next_job_id = None # most recently computed next job
    hqe_index = 0 # index for scanning by started_on times
    for task in special_tasks:
        if task['queue_entry']:
            # Direct link to a queue entry: take its job id.
            next_job_id = task['queue_entry']['job']
        elif task['time_started'] is not None:
            # Walk forward through the remaining entries, remembering the
            # last job that started at or after this task's start time.
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['started_on'] is None:
                    continue
                t1 = time_utils.time_string_to_datetime(
                        queue_entry['started_on'])
                t2 = time_utils.time_string_to_datetime(task['time_started'])
                if t1 < t2:
                    break
                next_job_id = queue_entry['job']['id']

        task['next_job_id'] = next_job_id

        # advance hqe_index to just after next_job_id
        if next_job_id is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['job']['id'] < next_job_id:
                    break
                hqe_index += 1
701
702
def interleave_entries(queue_entries, special_tasks):
    """
    Merge queue entries and special tasks into one list, newest first, with
    each job followed by the special tasks that ran after it.

    Both lists should be ordered by descending ID.

    @param queue_entries  Host queue entries as a list of dictionaries.
    @param special_tasks  Special tasks as a list of dictionaries.
    @returns a single interleaved list of dictionaries.
    """
    # Annotates each task with 'next_job_id' (consumed below).
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    # start with all special tasks that've run since the last job
    interleaved_entries = []
    for task in special_tasks:
        if task['next_job_id'] is not None:
            break
        interleaved_entries.append(task)

    # now interleave queue entries with the remaining special tasks
    special_task_index = len(interleaved_entries)
    for queue_entry in queue_entries:
        interleaved_entries.append(queue_entry)
        # add all tasks that ran between this job and the previous one
        for task in special_tasks[special_task_index:]:
            if task['next_job_id'] < queue_entry['job']['id']:
                break
            interleaved_entries.append(task)
            special_task_index += 1

    return interleaved_entries
728
729
def bucket_hosts_by_shard(host_objs):
    """Figure out which hosts are on which shards.

    @param host_objs: A list of host objects.

    @return: A map of shard hostname: list of hosts on the shard.
    """
    shard_host_map = collections.defaultdict(list)
    # Hosts without a shard are simply omitted from the map.
    for host in (h for h in host_objs if h.shard):
        shard_host_map[host.shard.hostname].append(host.hostname)
    return shard_host_map
742
743
def create_job_common(
        name,
        priority,
        control_type,
        control_file=None,
        hosts=(),
        meta_hosts=(),
        one_time_hosts=(),
        synch_count=None,
        is_template=False,
        timeout=None,
        timeout_mins=None,
        max_runtime_mins=None,
        run_verify=True,
        email_list='',
        dependencies=(),
        reboot_before=None,
        reboot_after=None,
        parse_failed_repair=None,
        hostless=False,
        keyvals=None,
        drone_set=None,
        parent_job_id=None,
        run_reset=True,
        require_ssp=None):
    #pylint: disable-msg=C0111
    """
    Common code between creating "standard" jobs and creating parameterized
    jobs.

    Validates the host arguments, resolves hostnames/label names into model
    objects, then delegates to create_new_job.

    @param hosts: hostnames of concrete hosts to run on.
    @param meta_hosts: label names to schedule against as metahosts.
    @param one_time_hosts: hostnames to create as one-time hosts.
    @param hostless: when True no host arguments may be passed and the
            control type must be server-side.
    @returns the id of the created job.
    @raises model_logic.ValidationError on invalid host/label arguments.
    """
    # input validation
    host_args_passed = any((hosts, meta_hosts, one_time_hosts))
    if hostless:
        if host_args_passed:
            raise model_logic.ValidationError({
                    'hostless': 'Hostless jobs cannot include any hosts!'})
        if control_type != control_data.CONTROL_TYPE_NAMES.SERVER:
            raise model_logic.ValidationError({
                    'control_type': 'Hostless jobs cannot use client-side '
                                    'control files'})
    elif not host_args_passed:
        raise model_logic.ValidationError({
            'arguments' : "For host jobs, you must pass at least one of"
                          " 'hosts', 'meta_hosts', 'one_time_hosts'."
            })
    label_objects = list(models.Label.objects.filter(name__in=meta_hosts))

    # convert hostnames & meta hosts to host/label objects
    host_objects = models.Host.smart_get_bulk(hosts)
    _validate_host_job_sharding(host_objects)
    for host in one_time_hosts:
        this_host = models.Host.create_one_time_host(host)
        host_objects.append(this_host)

    metahost_objects = []
    meta_host_labels_by_name = {label.name: label for label in label_objects}
    for label_name in meta_hosts:
        if label_name in meta_host_labels_by_name:
            metahost_objects.append(meta_host_labels_by_name[label_name])
        else:
            raise model_logic.ValidationError(
                {'meta_hosts' : 'Label "%s" not found' % label_name})

    options = dict(name=name,
                   priority=priority,
                   control_file=control_file,
                   control_type=control_type,
                   is_template=is_template,
                   timeout=timeout,
                   timeout_mins=timeout_mins,
                   max_runtime_mins=max_runtime_mins,
                   synch_count=synch_count,
                   run_verify=run_verify,
                   email_list=email_list,
                   dependencies=dependencies,
                   reboot_before=reboot_before,
                   reboot_after=reboot_after,
                   parse_failed_repair=parse_failed_repair,
                   keyvals=keyvals,
                   drone_set=drone_set,
                   parent_job_id=parent_job_id,
                   # TODO(crbug.com/873716) DEPRECATED. Remove entirely.
                   test_retry=0,
                   run_reset=run_reset,
                   require_ssp=require_ssp)

    return create_new_job(owner=models.User.current_user().login,
                          options=options,
                          host_objects=host_objects,
                          metahost_objects=metahost_objects)
833
834
def _validate_host_job_sharding(host_objects):
    """Check that the hosts obey job sharding rules.

    A single job may not span hosts on multiple shards, or mix hosts on a
    shard with hosts on the main scheduler (see _allowed_hosts_for_main_job).
    The check is skipped when running on a shard itself.

    @param host_objects: Host objects the job would run on.

    @raises ValueError: if the hosts violate the sharding rules.
    """
    if not (server_utils.is_shard()
            or _allowed_hosts_for_main_job(host_objects)):
        shard_host_map = bucket_hosts_by_shard(host_objects)
        # Fixed typo in the user-facing message ("seperate" -> "separate").
        raise ValueError(
                'The following hosts are on shard(s), please create '
                'separate jobs for hosts on each shard: %s ' %
                shard_host_map)
844
845
def _allowed_hosts_for_main_job(host_objects):
    """Check that the hosts are allowed for a job on main.

    A job is allowed on main only when its hosts do not span shards:
    either no host lives on a shard, or every host lives on the same
    single shard.

    @param host_objects: Host objects the job would run on.

    @returns: True if the job may run on main, False otherwise.
    """
    shard_host_map = bucket_hosts_by_shard(host_objects)
    if not shard_host_map:
        # No host is on a shard; the job stays entirely on main.
        return True
    if len(shard_host_map) > 1:
        # The job would span multiple shards.
        return False
    # Exactly one shard is involved: every host must be on it, otherwise
    # the job spans that shard and the main.
    (sharded_hosts,) = shard_host_map.values()
    assert len(sharded_hosts) <= len(host_objects)
    return len(sharded_hosts) == len(host_objects)
863
864
def encode_ascii(control_file):
    """Force a control file to only contain ascii characters.

    @param control_file: Control file to encode.

    @returns the control file in an ascii encoding.

    @raises error.ControlFileMalformed: if encoding fails.
    """
    try:
        return control_file.encode('ascii')
    except UnicodeError as e:
        # str.encode raises UnicodeEncodeError on Python 3 (the original
        # only caught UnicodeDecodeError, which is what Python 2 raised
        # when implicitly decoding a non-ascii byte string first).
        # UnicodeError is the common base class of both.
        raise error.ControlFileMalformed(str(e))
878
879
def get_wmatrix_url():
    """Read the wmatrix url from the [AUTOTEST_WEB] config section.

    @returns the configured wmatrix url, or '' when unset.
    """
    config = global_config.global_config
    return config.get_config_value('AUTOTEST_WEB', 'wmatrix_url', default='')
888
889
def get_stainless_url():
    """Read the stainless url from the [AUTOTEST_WEB] config section.

    @returns the configured stainless url, or '' when unset.
    """
    config = global_config.global_config
    return config.get_config_value('AUTOTEST_WEB', 'stainless_url',
                                   default='')
898
899
def inject_times_to_filter(start_time_key=None, end_time_key=None,
                           start_time_value=None, end_time_value=None,
                           **filter_data):
    """Add start/end time entries to filter_data when values are provided.

    @param start_time_key: Filter key under which to store the start time.
    @param end_time_key: Filter key under which to store the end time.
    @param start_time_value: Start time value; skipped when falsy.
    @param end_time_value: End time value; skipped when falsy.

    @returns the filter_data dict with the time entries injected.
    """
    for key, value in ((start_time_key, start_time_value),
                       (end_time_key, end_time_value)):
        if value:
            filter_data[key] = value
    return filter_data
917
918
def inject_times_to_hqe_special_tasks_filters(filter_data_common,
                                              start_time, end_time):
    """Build time-bounded filters for host queue entries and special tasks.

    @param filter_data_common: Common filter for hqes and special tasks.
    @param start_time: Inclusive lower time bound for both filters.
    @param end_time: Inclusive upper time bound for both filters.

    @returns a pair (hqe_filter, special_tasks_filter).
    """
    hqe_filter = inject_times_to_filter(
            'started_on__gte', 'started_on__lte',
            start_time, end_time, **filter_data_common)
    special_tasks_filter = inject_times_to_filter(
            'time_started__gte', 'time_started__lte',
            start_time, end_time, **filter_data_common.copy())
    return hqe_filter, special_tasks_filter
935
936
def retrieve_shard(shard_hostname):
    """Look up a shard in the database by hostname.

    @param shard_hostname: Hostname of the shard to retrieve.

    @raises models.Shard.DoesNotExist: if no shard with this hostname exists.

    @returns: the matching Shard object.
    """
    shard = models.Shard.smart_get(shard_hostname)
    return shard
948
949
def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Find records that should be sent to a shard.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple of lists:
              (hosts, jobs, suite_job_keyvals, invalid_host_ids)
    """
    hosts, invalid_host_ids = models.Host.assign_to_shard(shard,
                                                          known_host_ids)
    jobs = models.Job.assign_to_shard(shard, known_job_ids)
    # Suite-level keyvals live on the parent (suite) jobs of the assigned
    # jobs; ship those along as well.
    parent_ids = [job.parent_job_id for job in jobs]
    suite_job_keyvals = models.JobKeyval.objects.filter(job_id__in=parent_ids)
    return hosts, jobs, suite_job_keyvals, invalid_host_ids
967
968
969def _persist_records_with_type_sent_from_shard(
970    shard, records, record_type, *args, **kwargs):
971    """
972    Handle records of a specified type that were sent to the shard main.
973
974    @param shard: The shard the records were sent from.
975    @param records: The records sent in their serialized format.
976    @param record_type: Type of the objects represented by records.
977    @param args: Additional arguments that will be passed on to the checks.
978    @param kwargs: Additional arguments that will be passed on to the checks.
979
980    @raises error.UnallowedRecordsSentToMain if any of the checks fail.
981
982    @returns: List of primary keys of the processed records.
983    """
984    pks = []
985    for serialized_record in records:
986        pk = serialized_record['id']
987        try:
988            current_record = record_type.objects.get(pk=pk)
989        except record_type.DoesNotExist:
990            raise error.UnallowedRecordsSentToMain(
991                'Object with pk %s of type %s does not exist on main.' % (
992                    pk, record_type))
993
994        try:
995            current_record._check_update_from_shard(
996                shard, serialized_record, *args, **kwargs)
997        except error.IgnorableUnallowedRecordsSentToMain:
998            # An illegal record change was attempted, but it was of a non-fatal
999            # variety. Silently skip this record.
1000            pass
1001        else:
1002            current_record.update_from_serialized(serialized_record)
1003            pks.append(pk)
1004
1005    return pks
1006
1007
def persist_records_sent_from_shard(shard, jobs, hqes):
    """Check and save serialized records sent to main from a shard.

    During heartbeats shards upload jobs and host queue entries. This
    performs some checks on them and then updates the existing records
    with the uploaded ones.

    The checks include:
    - Checking if the objects sent already exist on the main.
    - Checking if the objects sent were assigned to this shard.
    - Host queue entries must be sent together with their jobs.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The host queue entries the shard sent.

    @raises error.UnallowedRecordsSentToMain if any of the checks fail.
    """
    # Jobs must be persisted first so the hqe check below can verify that
    # each hqe arrived together with its job.
    persisted_job_ids = _persist_records_with_type_sent_from_shard(
            shard, jobs, models.Job)
    _persist_records_with_type_sent_from_shard(
            shard, hqes, models.HostQueueEntry,
            job_ids_sent=persisted_job_ids)
1032
1033
def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change the host attributes
    should be forwarded to the shard.

    This assumes the first argument of the function represents a host id.

    @param func: The function to decorate

    @returns: The function to replace func with.
    """
    # @wraps preserves func's __name__/__doc__ on the wrapper, consistent
    # with route_rpc_to_main below (the original wrapper lost them).
    @wraps(func)
    def replacement(**kwargs):
        # Only keyword arguments can be accepted here, as we need the argument
        # names to send the rpc. serviceHandler always provides arguments with
        # their keywords, so this is not a problem.

        # A host record (identified by kwargs['id']) can be deleted in
        # func(). Therefore, we should save the data that can be needed later
        # before func() is called.
        shard_hostname = None
        host = models.Host.smart_get(kwargs['id'])
        if host and host.shard:
            shard_hostname = host.shard.hostname
        ret = func(**kwargs)
        if shard_hostname and not server_utils.is_shard():
            run_rpc_on_multiple_hostnames(func.__name__, [shard_hostname],
                                          **kwargs)
        return ret

    return replacement
1065
1066
def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
    """Forward the given rpc to every shard owning one of the given hosts.

    @param host_objs: Host objects the rpc concerns.
    @param rpc_name: Name of the rpc to run on each shard.
    @param include_hostnames: If True, pass the shard's hostnames to the rpc
        via the 'hosts' keyword. Hostnames are not always necessary; the
        rpcs themselves could be related to labels, acls etc.
    @param kwargs: The kwargs for the rpc.

    @raises error.RPCException: if the rpc fails on any shard.
    """
    # Group the hostnames by the shard each host lives on.
    shard_host_map = bucket_hosts_by_shard(host_objs)

    # Run the rpc once per shard that owns at least one of the hosts.
    for shard_hostname, hostnames in six.iteritems(shard_host_map):
        if include_hostnames:
            kwargs['hosts'] = hostnames
        try:
            run_rpc_on_multiple_hostnames(rpc_name, [shard_hostname],
                                          **kwargs)
        except Exception as e:
            raise error.RPCException('RPC %s failed on shard %s due to %s' %
                                     (rpc_name, shard_hostname, e))
1090
1091
def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
    """Run an rpc against multiple AFEs.

    Used e.g. to propagate changes made to hosts after they are assigned
    to a shard.

    @param rpc_call: Name of the rpc endpoint to call.
    @param shard_hostnames: List of hostnames to run the rpcs on.
    @param **kwargs: Keyword arguments to pass in the rpcs.
    """
    # This function must only be called on the main, never on a shard.
    assert not server_utils.is_shard()
    for hostname in shard_hostnames:
        shard_afe = frontend_wrappers.RetryingAFE(
                server=hostname, user=thread_local.get_user())
        shard_afe.run(rpc_call, **kwargs)
1108
1109
def get_label(name):
    """Get a label object by name.

    @param name: Label name.

    @return: the matching label object, or None when no label with the
        given name exists (DoesNotExist is swallowed, not propagated).
    """
    try:
        return models.Label.smart_get(name)
    except models.Label.DoesNotExist:
        return None
1123
1124
# TODO: hide the following rpcs under is_moblab
def moblab_only(func):
    """Ensure moblab specific functions only run on Moblab devices.

    @param func: The RPC function to protect.

    @returns: A wrapper that raises error.RPCException when called on a
        non-Moblab system and otherwise calls func unchanged.
    """
    @wraps(func)
    def verify(*args, **kwargs):
        if not server_utils.is_moblab():
            # The original passed func.__name__ as a separate constructor
            # argument, so the '%s' placeholder was never interpolated;
            # format the message explicitly instead.
            raise error.RPCException('RPC: %s can only run on Moblab Systems!'
                                     % func.__name__)
        return func(*args, **kwargs)
    return verify
1134
1135
def route_rpc_to_main(func):
    """Route RPC to main AFE.

    When a shard receives an RPC decorated by this, the RPC is just
    forwarded to the main.
    When the main gets the RPC, the RPC function is executed.

    @param func: An RPC function to decorate

    @returns: A function replacing the RPC func.
    """
    # inspect.getargspec was deprecated since Python 3.0 and removed in
    # 3.11; getfullargspec is the drop-in replacement and exposes the same
    # .varargs attribute.
    argspec = inspect.getfullargspec(func)
    if argspec.varargs is not None:
        raise Exception('RPC function must not have *args.')

    @wraps(func)
    def replacement(*args, **kwargs):
        """We need special handling when decorating an RPC that can be called
        directly using positional arguments.

        One example is rpc_interface.create_job().
        rpc_interface.create_job_page_handler() calls the function using both
        positional and keyword arguments.  Since frontend.RpcClient.run()
        takes only keyword arguments for an RPC, positional arguments of the
        RPC function need to be transformed into keyword arguments.
        """
        kwargs = _convert_to_kwargs_only(func, args, kwargs)
        if server_utils.is_shard():
            afe = frontend_wrappers.RetryingAFE(
                    server=server_utils.get_global_afe_hostname(),
                    user=thread_local.get_user())
            return afe.run(func.__name__, **kwargs)
        return func(**kwargs)

    return replacement
1171
1172
1173def _convert_to_kwargs_only(func, args, kwargs):
1174    """Convert a function call's arguments to a kwargs dict.
1175
1176    This is best illustrated with an example.  Given:
1177
1178    def foo(a, b, **kwargs):
1179        pass
1180    _to_kwargs(foo, (1, 2), {'c': 3})  # corresponding to foo(1, 2, c=3)
1181
1182        foo(**kwargs)
1183
1184    @param func: function whose signature to use
1185    @param args: positional arguments of call
1186    @param kwargs: keyword arguments of call
1187
1188    @returns: kwargs dict
1189    """
1190    argspec = inspect.getargspec(func)
1191    # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}}
1192    callargs = inspect.getcallargs(func, *args, **kwargs)
1193    if argspec.keywords is None:
1194        kwargs = {}
1195    else:
1196        kwargs = callargs.pop(argspec.keywords)
1197    kwargs.update(callargs)
1198    return kwargs
1199
1200
def get_sample_dut(board, pool):
    """Get a dut with the given board and pool.

    This method is used to help to locate a dut with the given board and
    pool.  The dut then can be used to identify a devserver in the same
    subnet.

    @param board: Name of the board.
    @param pool: Name of the pool.

    @return: Name of a dut with the given board and pool, or None when no
        such dut exists or local devservers are not preferred.
    """
    # Only bother when local devservers are preferred and both labels are
    # provided.
    if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board):
        return None

    matching_hosts = list(get_host_query(
        multiple_labels=('pool:%s' % pool, 'board:%s' % board),
        exclude_only_if_needed_labels=False,
        valid_only=True,
        filter_data={},
    ))
    return matching_hosts[0].hostname if matching_hosts else None
1225