• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# pylint: disable=missing-docstring
2"""
3Utility functions for rpc_interface.py.  We keep them in a separate file so that
4only RPC interface functions go into that file.
5"""
6
7__author__ = 'showard@google.com (Steve Howard)'
8
9import collections
10import datetime
11from functools import wraps
12import inspect
13import os
14import sys
15import django.db.utils
16import django.http
17
18from autotest_lib.frontend import thread_local
19from autotest_lib.frontend.afe import models, model_logic
20from autotest_lib.client.common_lib import control_data, error
21from autotest_lib.client.common_lib import global_config
22from autotest_lib.client.common_lib import time_utils
23from autotest_lib.client.common_lib.cros import dev_server
24from autotest_lib.server import utils as server_utils
25from autotest_lib.server.cros import provision
26from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
27
# Sentinel date values: datetime.max / date.max are stored to mean "no date";
# _prepare_data() serializes them as None when returning data over RPC.
NULL_DATETIME = datetime.datetime.max
NULL_DATE = datetime.date.max
# Substring MySQL puts in IntegrityError messages for duplicate-key failures;
# used by _ensure_label_exists() to detect label-creation races.
DUPLICATE_KEY_MSG = 'Duplicate entry'
31
def prepare_for_serialization(objects):
    """
    Prepare Python objects so they can be returned through an RPC.

    @param objects: objects to be prepared.
    """
    looks_like_object_dicts = (isinstance(objects, list) and objects and
                               isinstance(objects[0], dict) and
                               'id' in objects[0])
    if looks_like_object_dicts:
        # De-duplicate by object id before serializing.
        objects = _gather_unique_dicts(objects)
    return _prepare_data(objects)
41
42
def prepare_rows_as_nested_dicts(query, nested_dict_column_names):
    """
    Prepare a Django query to be returned via RPC as a sequence of nested
    dictionaries.

    @param query - A Django model query object with a select_related() method.
    @param nested_dict_column_names - A list of column/attribute names for the
            rows returned by query to expand into nested dictionaries using
            their get_object_dict() method when not None.

    @returns A list suitable to be returned in an RPC.
    """
    results = []
    for row in query.select_related():
        as_dict = row.get_object_dict()
        for name in nested_dict_column_names:
            if as_dict[name] is not None:
                # Expand the related object into its own dictionary.
                as_dict[name] = getattr(row, name).get_object_dict()
        results.append(as_dict)
    return prepare_for_serialization(results)
63
64
65def _prepare_data(data):
66    """
67    Recursively process data structures, performing necessary type
68    conversions to values in data to allow for RPC serialization:
69    -convert datetimes to strings
70    -convert tuples and sets to lists
71    """
72    if isinstance(data, dict):
73        new_data = {}
74        for key, value in data.iteritems():
75            new_data[key] = _prepare_data(value)
76        return new_data
77    elif (isinstance(data, list) or isinstance(data, tuple) or
78          isinstance(data, set)):
79        return [_prepare_data(item) for item in data]
80    elif isinstance(data, datetime.date):
81        if data is NULL_DATETIME or data is NULL_DATE:
82            return None
83        return str(data)
84    else:
85        return data
86
87
def fetchall_as_list_of_dicts(cursor):
    """
    Convert every row in the cursor to a dictionary so that values can be
    read by using the column name.

    @param cursor: The database cursor to read from.
    @returns: A list of each row in the cursor as a dictionary.
    """
    column_names = [column[0] for column in cursor.description]
    return [dict(zip(column_names, row)) for row in cursor.fetchall()]
98
99
def raw_http_response(response_data, content_type=None):
    """
    Wrap raw response data in an HttpResponse with an explicit
    Content-length header.

    @param response_data: the raw body to send back.
    @param content_type: optional MIME type for the response.
    @returns a django.http.HttpResponse ready to be returned from a view.
    """
    # NOTE(review): the 'mimetype' keyword was removed in Django 1.7 in
    # favor of 'content_type' -- confirm the Django version this runs on.
    response = django.http.HttpResponse(response_data, mimetype=content_type)
    response['Content-length'] = str(len(response.content))
    return response
104
105
106def _gather_unique_dicts(dict_iterable):
107    """\
108    Pick out unique objects (by ID) from an iterable of object dicts.
109    """
110    objects = collections.OrderedDict()
111    for obj in dict_iterable:
112        objects.setdefault(obj['id'], obj)
113    return objects.values()
114
115
def extra_job_status_filters(not_yet_run=False, running=False, finished=False):
    """\
    Generate a SQL WHERE clause for job status filtering, and return it in
    a dict of keyword args to pass to query.extra().
    * not_yet_run: all HQEs are Queued
    * finished: all HQEs are complete
    * running: everything else
    """
    if not any((not_yet_run, running, finished)):
        return {}
    not_queued = ('(SELECT job_id FROM afe_host_queue_entries '
                  'WHERE status != "%s")'
                  % models.HostQueueEntry.Status.QUEUED)
    not_finished = ('(SELECT job_id FROM afe_host_queue_entries '
                    'WHERE not complete)')

    clauses = []
    if not_yet_run:
        clauses.append('id NOT IN ' + not_queued)
    if running:
        clauses.append('(id IN %s) AND (id IN %s)' % (not_queued, not_finished))
    if finished:
        clauses.append('id NOT IN ' + not_finished)
    combined = ' OR '.join('(%s)' % clause for clause in clauses)
    return {'where': [combined]}
140
141
def extra_job_type_filters(extra_args, suite=False,
                           sub=False, standalone=False):
    """\
    Generate a SQL WHERE clause for job status filtering, and return it in
    a dict of keyword args to pass to query.extra().

    param extra_args: a dict of existing extra_args.

    No more than one of the parameters should be passed as True:
    * suite: job which is parent of other jobs
    * sub: job with a parent job
    * standalone: job with no child or parent jobs
    """
    assert not ((suite and sub) or
                (suite and standalone) or
                (sub and standalone)), ('Cannot specify more than one '
                                        'filter to this function')

    where = extra_args.get('where', [])
    filter_common = ('(SELECT %s FROM afe_jobs '
                     'WHERE parent_job_id IS NOT NULL)')

    if suite:
        # Parent jobs: select the distinct parents of all child jobs.
        where.append('id IN ' + filter_common % 'DISTINCT parent_job_id')
    elif sub:
        # Child jobs: any job that names a parent.
        where.append('id IN ' + filter_common % 'id')
    elif standalone:
        where.append('NOT EXISTS (SELECT 1 from afe_jobs AS sub_query '
                     'WHERE parent_job_id IS NOT NULL'
                     ' AND (sub_query.parent_job_id=afe_jobs.id'
                     ' OR sub_query.id=afe_jobs.id))')
    else:
        return extra_args

    extra_args['where'] = where
    return extra_args
180
181
182
def extra_host_filters(multiple_labels=()):
    """\
    Generate SQL WHERE clauses for matching hosts in an intersection of
    labels.
    """
    clause = ('afe_hosts.id in (select host_id from afe_hosts_labels '
              'where label_id=%s)')
    return {
        'where': [clause for _ in multiple_labels],
        'params': [models.Label.smart_get(label).id
                   for label in multiple_labels],
    }
195
196
def get_host_query(multiple_labels, exclude_only_if_needed_labels,
                   valid_only, filter_data):
    """
    Build a Host queryset filtered by labels and generic filter data.

    @param multiple_labels: label names the matched hosts must all carry.
    @param exclude_only_if_needed_labels: if True, exclude hosts carrying any
            label marked only_if_needed.
    @param valid_only: if True, start from valid hosts only.
    @param filter_data: extra query filters passed to Host.query_objects();
            mutated here to carry 'extra_args'.
    @returns a Host queryset, or an empty queryset if a label name in
            multiple_labels does not exist.
    """
    if valid_only:
        query = models.Host.valid_objects.all()
    else:
        query = models.Host.objects.all()

    if exclude_only_if_needed_labels:
        only_if_needed_labels = models.Label.valid_objects.filter(
            only_if_needed=True)
        if only_if_needed_labels.count() > 0:
            only_if_needed_ids = ','.join(
                    str(label['id'])
                    for label in only_if_needed_labels.values('id'))
            # Anti-join (exclude=True): drop hosts that carry any of the
            # only_if_needed labels.
            query = models.Host.objects.add_join(
                query, 'afe_hosts_labels', join_key='host_id',
                join_condition=('afe_hosts_labels_exclude_OIN.label_id IN (%s)'
                                % only_if_needed_ids),
                suffix='_exclude_OIN', exclude=True)
    try:
        assert 'extra_args' not in filter_data
        filter_data['extra_args'] = extra_host_filters(multiple_labels)
        return models.Host.query_objects(filter_data, initial_query=query)
    except models.Label.DoesNotExist:
        # An unknown label name can match no hosts.
        return models.Host.objects.none()
222
223
class InconsistencyException(Exception):
    """Raised when a list of objects does not have a consistent value."""
226
227
def get_consistent_value(objects, field):
    """
    Return the value of |field| that every object in |objects| shares.

    @param objects: sequence of objects to inspect.
    @param field: attribute name to read from each object.
    @raises InconsistencyException: if any two objects disagree.
    @returns the common value, or None for an empty sequence.
    """
    if not objects:
        # An empty collection is trivially consistent.
        return None

    expected = getattr(objects[0], field)
    for candidate in objects:
        if getattr(candidate, field) != expected:
            raise InconsistencyException(objects[0], candidate)
    return expected
239
240
def afe_test_dict_to_test_object(test_dict):
    """
    Build an anonymous TestObject from a test dictionary, coercing values
    to int where possible. Non-dict input is returned unchanged.
    """
    if not isinstance(test_dict, dict):
        return test_dict

    coerced = {}
    for key, value in test_dict.iteritems():
        try:
            coerced[key] = int(value)
        except (ValueError, TypeError):
            coerced[key] = value

    return type('TestObject', (object,), coerced)
253
254
def _check_is_server_test(test_type):
    """Checks if the test type is a server test.

    @param test_type The test type in enum integer or string.

    @returns A boolean to identify if the test type is server test.
    """
    if test_type is None:
        return False
    if isinstance(test_type, basestring):
        try:
            test_type = control_data.CONTROL_TYPE.get_value(test_type)
        except AttributeError:
            # Unknown type name: not a server test.
            return False
    return test_type == control_data.CONTROL_TYPE.SERVER
270
271
def prepare_generate_control_file(tests, profilers, db_tests=True):
    """
    Resolve tests/profilers and compute the metadata needed to generate a
    control file.

    @param tests: test identifiers; DB ids/names when db_tests is True,
            otherwise test dictionaries.
    @param profilers: profiler identifiers, resolved via smart_get.
    @param db_tests: whether tests refer to database Test rows.
    @raises model_logic.ValidationError: if the tests mix client- and
            server-side test types.
    @returns a (cf_info, test_objects, profiler_objects) tuple.
    """
    if db_tests:
        test_objects = [models.Test.smart_get(test) for test in tests]
    else:
        test_objects = [afe_test_dict_to_test_object(test) for test in tests]

    profiler_objects = [models.Profiler.smart_get(profiler)
                        for profiler in profilers]
    # ensure tests are all the same type
    try:
        test_type = get_consistent_value(test_objects, 'test_type')
    # 'except ... as' matches the style used elsewhere in this module and
    # stays valid under python 3; the comma form does not.
    except InconsistencyException as exc:
        test1, test2 = exc.args
        raise model_logic.ValidationError(
            {'tests' : 'You cannot run both test_suites and server-side '
             'tests together (tests %s and %s differ)' % (
            test1.name, test2.name)})

    is_server = _check_is_server_test(test_type)
    if test_objects:
        synch_count = max(test.sync_count for test in test_objects)
    else:
        synch_count = 1

    if db_tests:
        dependencies = set(label.name for label
                           in models.Label.objects.filter(test__in=test_objects))
    else:
        # set().union(*...) tolerates an empty test list, unlike
        # reduce(set.union, ...) which raises TypeError on empty input.
        dependencies = set().union(
                *[set(test.dependencies) for test in test_objects])

    cf_info = dict(is_server=is_server, synch_count=synch_count,
                   dependencies=list(dependencies))
    return cf_info, test_objects, profiler_objects
306
307
def check_job_dependencies(host_objects, job_dependencies):
    """
    Check that a set of machines satisfies a job's dependencies.

    @param host_objects: list of models.Host objects.
    @param job_dependencies: list of names of labels.
    @raises model_logic.ValidationError: if any host is missing a required
            (non-provisionable) dependency label.
    """
    # check that hosts satisfy dependencies
    host_ids = [host.id for host in host_objects]
    ok_hosts = models.Host.objects.filter(id__in=host_ids)
    # The enumerate() index in the original loop was never used.
    for dependency in job_dependencies:
        # Labels handled by provisioning are applied later, so they need
        # not already be on the host.
        if not provision.is_for_special_action(dependency):
            ok_hosts = ok_hosts.filter(labels__name=dependency)
    failing_hosts = (set(host.hostname for host in host_objects) -
                     set(host.hostname for host in ok_hosts))
    if failing_hosts:
        raise model_logic.ValidationError(
            {'hosts' : 'Host(s) failed to meet job dependencies (' +
                       (', '.join(job_dependencies)) + '): ' +
                       (', '.join(failing_hosts))})
328
329
def check_job_metahost_dependencies(metahost_objects, job_dependencies):
    """
    Check that at least one machine within the metahost spec satisfies the job's
    dependencies.

    @param metahost_objects A list of label objects representing the metahosts.
    @param job_dependencies A list of strings of the required label names.
    @raises NoEligibleHostException If a metahost cannot run the job.
    """
    for metahost in metahost_objects:
        candidates = models.Host.objects.filter(labels=metahost)
        for label_name in job_dependencies:
            # Provisionable labels are applied later; skip them here.
            if provision.is_for_special_action(label_name):
                continue
            candidates = candidates.filter(labels__name=label_name)
        if not any(candidates):
            raise error.NoEligibleHostException("No hosts within %s satisfy %s."
                    % (metahost.name, ', '.join(job_dependencies)))
347
348
349def _execution_key_for(host_queue_entry):
350    return (host_queue_entry.job.id, host_queue_entry.execution_subdir)
351
352
def check_abort_synchronous_jobs(host_queue_entries):
    """
    Ensure the user isn't aborting only part of a synchronous autoserv
    execution.

    @raises model_logic.ValidationError: if fewer entries of an execution
            are being aborted than the job's synch_count requires.
    """
    # Count how many of the given entries belong to each execution.
    aborted_per_execution = {}
    for entry in host_queue_entries:
        key = _execution_key_for(entry)
        aborted_per_execution[key] = aborted_per_execution.get(key, 0) + 1

    for entry in host_queue_entries:
        if not entry.execution_subdir:
            continue
        num_aborted = aborted_per_execution[_execution_key_for(entry)]
        if num_aborted < entry.job.synch_count:
            raise model_logic.ValidationError(
                {'' : 'You cannot abort part of a synchronous job execution '
                      '(%d/%s), %d included, %d expected'
                      % (entry.job.id, entry.execution_subdir,
                         num_aborted, entry.job.synch_count)})
371
372
def check_modify_host(update_data):
    """
    Sanity check modify_host* requests.

    @param update_data: A dictionary with the changes to make to a host
            or hosts.
    @raises model_logic.ValidationError: if the request tries to change
            a host's status.
    """
    # Only the scheduler (monitor_db) is allowed to modify Host status.
    # Otherwise race conditions happen as a hosts state is changed out from
    # beneath tasks being run on a host.
    if 'status' not in update_data:
        return
    raise model_logic.ValidationError({
            'status': 'Host status can not be modified by the frontend.'})
386
387
def check_modify_host_locking(host, update_data):
    """
    Checks when locking/unlocking has been requested if the host is already
    locked/unlocked.

    @param host: models.Host object to be modified
    @param update_data: A dictionary with the changes to make to the host.
    """
    locked = update_data.get('locked', None)
    lock_reason = update_data.get('lock_reason', None)
    if locked is None:
        # No lock change requested; nothing to validate.
        return
    if locked and host.locked:
        raise model_logic.ValidationError({
                'locked': 'Host %s already locked by %s on %s.' %
                (host.hostname, host.locked_by, host.lock_time)})
    if not locked and not host.locked:
        raise model_logic.ValidationError({
                'locked': 'Host %s already unlocked.' % host.hostname})
    if locked and not lock_reason and not host.locked:
        raise model_logic.ValidationError({
                'locked': 'Please provide a reason for locking Host %s' %
                host.hostname})
410
411
def get_motd():
    """
    Return the message-of-the-day text, or '' if motd.txt is unreadable.
    """
    dirname = os.path.dirname(__file__)
    filename = os.path.join(dirname, "..", "..", "motd.txt")
    try:
        # 'with' guarantees the file is closed even if read() fails, and
        # the narrowed except no longer swallows unrelated errors (the old
        # bare 'except:' even caught KeyboardInterrupt).
        with open(filename, "r") as fp:
            return fp.read()
    except (IOError, OSError):
        # A missing or unreadable MOTD is not an error; show nothing.
        return ''
426
427
428def _get_metahost_counts(metahost_objects):
429    metahost_counts = {}
430    for metahost in metahost_objects:
431        metahost_counts.setdefault(metahost, 0)
432        metahost_counts[metahost] += 1
433    return metahost_counts
434
435
def get_job_info(job, preserve_metahosts=False, queue_entry_filter_data=None):
    """
    Summarize a job's host assignments from its host queue entries.

    @param job: models.Job object to inspect.
    @param preserve_metahosts: if True, an entry that has both a host and a
            meta_host is still reported under hosts/one_time_hosts rather
            than falling through to meta_hosts.
    @param queue_entry_filter_data: optional filters applied to the job's
            queue entries before summarizing.
    @returns a dict with keys: dependencies, hosts, meta_hosts,
            meta_host_counts, one_time_hosts, hostless.
    """
    hosts = []
    one_time_hosts = []
    meta_hosts = []
    hostless = False

    queue_entries = job.hostqueueentry_set.all()
    if queue_entry_filter_data:
        queue_entries = models.HostQueueEntry.query_objects(
            queue_entry_filter_data, initial_query=queue_entries)

    for queue_entry in queue_entries:
        if (queue_entry.host and (preserve_metahosts or
                                  not queue_entry.meta_host)):
            if queue_entry.deleted:
                continue
            if queue_entry.host.invalid:
                # Invalid hosts are treated as one-time hosts.
                one_time_hosts.append(queue_entry.host)
            else:
                hosts.append(queue_entry.host)
        elif queue_entry.meta_host:
            meta_hosts.append(queue_entry.meta_host)
        else:
            # Neither host nor meta_host: the job is hostless.
            hostless = True

    meta_host_counts = _get_metahost_counts(meta_hosts)

    info = dict(dependencies=[label.name for label
                              in job.dependency_labels.all()],
                hosts=hosts,
                meta_hosts=meta_hosts,
                meta_host_counts=meta_host_counts,
                one_time_hosts=one_time_hosts,
                hostless=hostless)
    return info
471
472
def check_for_duplicate_hosts(host_objects):
    """
    Raise a ValidationError if any hostname appears more than once.

    @param host_objects: list of Host objects (must be hashable).
    @raises model_logic.ValidationError: listing the duplicated hostnames.
    """
    host_counts = collections.Counter(host_objects)
    # .items() instead of the py2-only .iteritems(): identical behavior in
    # python 2, and keeps this helper usable under python 3.
    duplicate_hostnames = {host.hostname
                           for host, count in host_counts.items()
                           if count > 1}
    if duplicate_hostnames:
        raise model_logic.ValidationError(
                {'hosts' : 'Duplicate hosts: %s'
                 % ', '.join(duplicate_hostnames)})
482
483
def create_new_job(owner, options, host_objects, metahost_objects):
    """
    Validate job parameters, create the Job row and queue its entries.

    @param owner: login of the user creating the job.
    @param options: dict of job options; its 'dependencies' entry (label
            names) is replaced in place with the matching Label objects.
    @param host_objects: explicit Host objects the job will run on.
    @param metahost_objects: Label objects used as metahosts.
    @raises model_logic.ValidationError: on a synch_count larger than the
            host count, duplicate hosts, or unsatisfied dependencies.
    @returns the id of the newly created Job.
    """
    all_host_objects = host_objects + metahost_objects
    dependencies = options.get('dependencies', [])
    synch_count = options.get('synch_count')

    if synch_count is not None and synch_count > len(all_host_objects):
        raise model_logic.ValidationError(
                {'hosts':
                 'only %d hosts provided for job with synch_count = %d' %
                 (len(all_host_objects), synch_count)})

    check_for_duplicate_hosts(host_objects)

    for label_name in dependencies:
        if provision.is_for_special_action(label_name):
            # TODO: We could save a few queries
            # if we had a bulk ensure-label-exists function, which used
            # a bulk .get() call. The win is probably very small.
            _ensure_label_exists(label_name)

    # This only checks targeted hosts, not hosts eligible due to the metahost
    check_job_dependencies(host_objects, dependencies)
    check_job_metahost_dependencies(metahost_objects, dependencies)

    options['dependencies'] = list(
            models.Label.objects.filter(name__in=dependencies))

    job = models.Job.create(owner=owner, options=options,
                            hosts=all_host_objects)
    job.queue(all_host_objects,
              is_template=options.get('is_template', False))
    return job.id
516
517
def _ensure_label_exists(name):
    """
    Ensure that a label called |name| exists in the Django models.

    This function is to be called from within afe rpcs only, as an
    alternative to server.cros.provision.ensure_label_exists(...). It works
    by Django model manipulation, rather than by making another create_label
    rpc call.

    @param name: the label to check for/create.
    @raises ValidationError: There was an error in the response that was
                             not because the label already existed.
    @returns True if a label was created, False otherwise.
    """
    # Make sure this function is not called on shards but only on master.
    assert not server_utils.is_shard()
    try:
        models.Label.objects.get(name=name)
    except models.Label.DoesNotExist:
        try:
            # objects.create() already persists the new row; the explicit
            # save() the original code made afterwards was a redundant
            # second write.
            models.Label.objects.create(name=name)
            return True
        except django.db.utils.IntegrityError as e:
            # It is possible that another suite/test already
            # created the label between the check and save.
            if DUPLICATE_KEY_MSG in str(e):
                return False
            else:
                raise
    return False
549
550
def find_platform(host):
    """
    Figure out the platform name for the given host
    object.  If none, the return value for either will be None.

    @returns platform name for the given host.
    @raises ValueError: if the host carries more than one platform label.
    """
    platforms = [label.name for label in host.label_list if label.platform]
    if len(platforms) > 1:
        raise ValueError('Host %s has more than one platform: %s' %
                         (host.hostname, ', '.join(platforms)))
    if platforms:
        return platforms[0]
    return None
567
568
569# support for get_host_queue_entries_and_special_tasks()
570
571def _common_entry_to_dict(entry, type, job_dict, exec_path, status, started_on):
572    return dict(type=type,
573                host=entry['host'],
574                job=job_dict,
575                execution_path=exec_path,
576                status=status,
577                started_on=started_on,
578                id=str(entry['id']) + type,
579                oid=entry['id'])
580
581
def _special_task_to_dict(task, queue_entries):
    """Transforms a special task dictionary to another form of dictionary.

    @param task           Special task as a dictionary type
    @param queue_entries  Host queue entries as a list of dictionaries.

    @return Transformed dictionary for a special task.
    """
    job_dict = None
    task_entry = task['queue_entry']
    if task_entry:
        # Look for the job detail among the supplied queue entries first.
        for candidate in queue_entries:
            if task_entry['id'] == candidate['id']:
                job_dict = candidate['job']
                break
        if job_dict is None:
            # Not found (or the matching entry carried no job): fall back
            # to the database.
            job_dict = models.Job.objects.get(
                    id=task_entry['job']).get_object_dict()

    exec_path = server_utils.get_special_task_exec_path(
            task['host']['hostname'], task['id'], task['task'],
            time_utils.time_string_to_datetime(task['time_requested']))
    status = server_utils.get_special_task_status(
            task['is_complete'], task['success'], task['is_active'])
    return _common_entry_to_dict(task, task['task'], job_dict,
            exec_path, status, task['time_started'])
609
610
def _queue_entry_to_dict(queue_entry):
    """Transform a host queue entry dict into the common entry shape."""
    job_dict = queue_entry['job']
    tag = server_utils.get_job_tag(job_dict['id'], job_dict['owner'])
    exec_path = server_utils.get_hqe_exec_path(
            tag, queue_entry['execution_subdir'])
    return _common_entry_to_dict(queue_entry, 'Job', job_dict, exec_path,
            queue_entry['status'], queue_entry['started_on'])
618
619
def prepare_host_queue_entries_and_special_tasks(interleaved_entries,
                                                 queue_entries):
    """
    Prepare for serialization the interleaved entries of host queue entries
    and special tasks.
    Each element in the entries is a dictionary type.
    The special task dictionary has only a job id for a job and lacks
    the detail of the job while the host queue entry dictionary has.
    queue_entries is used to look up the job detail info.

    @param interleaved_entries  Host queue entries and special tasks as a list
                                of dictionaries.
    @param queue_entries        Host queue entries as a list of dictionaries.

    @return A post-processed list of dictionaries that is to be serialized.
    """
    def _to_dict(entry):
        # Entries carrying a 'task' key are special tasks; everything else
        # is a host queue entry.
        if 'task' in entry:
            return _special_task_to_dict(entry, queue_entries)
        return _queue_entry_to_dict(entry)

    return prepare_for_serialization(
            [_to_dict(entry) for entry in interleaved_entries])
646
647
def _compute_next_job_for_tasks(queue_entries, special_tasks):
    """
    For each task, try to figure out the next job that ran after that task.
    This is done using two pieces of information:
    * if the task has a queue entry, we can use that entry's job ID.
    * if the task has a time_started, we can try to compare that against the
      started_on field of queue_entries. this isn't guaranteed to work perfectly
      since queue_entries may also have null started_on values.
    * if the task has neither, or if use of time_started fails, just use the
      last computed job ID.

    Mutates each task dict in place, adding a 'next_job_id' key.

    @param queue_entries    Host queue entries as a list of dictionaries.
    @param special_tasks    Special tasks as a list of dictionaries.
    """
    next_job_id = None # most recently computed next job
    hqe_index = 0 # index for scanning by started_on times
    for task in special_tasks:
        if task['queue_entry']:
            next_job_id = task['queue_entry']['job']
        elif task['time_started'] is not None:
            # Walk the remaining entries; keep the job id of every entry
            # that started at or after this task, stopping at the first
            # one that started earlier.
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['started_on'] is None:
                    continue
                t1 = time_utils.time_string_to_datetime(
                        queue_entry['started_on'])
                t2 = time_utils.time_string_to_datetime(task['time_started'])
                if t1 < t2:
                    break
                next_job_id = queue_entry['job']['id']

        task['next_job_id'] = next_job_id

        # advance hqe_index to just after next_job_id
        if next_job_id is not None:
            for queue_entry in queue_entries[hqe_index:]:
                if queue_entry['job']['id'] < next_job_id:
                    break
                hqe_index += 1
686
687
def interleave_entries(queue_entries, special_tasks):
    """
    Merge queue entries and special tasks into one list, placing each task
    next to the queue entry it ran around.

    Both lists should be ordered by descending ID.

    @param queue_entries    Host queue entries as a list of dictionaries.
    @param special_tasks    Special tasks as a list of dictionaries.
    @returns a single interleaved list of entry/task dictionaries.
    """
    _compute_next_job_for_tasks(queue_entries, special_tasks)

    # start with all special tasks that've run since the last job
    # (those have next_job_id == None)
    interleaved_entries = []
    for task in special_tasks:
        if task['next_job_id'] is not None:
            break
        interleaved_entries.append(task)

    # now interleave queue entries with the remaining special tasks
    special_task_index = len(interleaved_entries)
    for queue_entry in queue_entries:
        interleaved_entries.append(queue_entry)
        # add all tasks that ran between this job and the previous one
        for task in special_tasks[special_task_index:]:
            if task['next_job_id'] < queue_entry['job']['id']:
                break
            interleaved_entries.append(task)
            special_task_index += 1

    return interleaved_entries
713
714
def bucket_hosts_by_shard(host_objs, rpc_hostnames=False):
    """Figure out which hosts are on which shards.

    @param host_objs: A list of host objects.
    @param rpc_hostnames: If True, the rpc_hostnames of a shard are returned
        instead of the 'real' shard hostnames. This only matters for testing
        environments.

    @return: A map of shard hostname: list of hosts on the shard.
    """
    shard_host_map = collections.defaultdict(list)
    for host in host_objs:
        shard = host.shard
        if not shard:
            # Hosts without a shard live on the master; skip them.
            continue
        if rpc_hostnames:
            key = shard.rpc_hostname()
        else:
            key = shard.hostname
        shard_host_map[key].append(host.hostname)
    return shard_host_map
732
733
def create_job_common(
        name,
        priority,
        control_type,
        control_file=None,
        hosts=(),
        meta_hosts=(),
        one_time_hosts=(),
        synch_count=None,
        is_template=False,
        timeout=None,
        timeout_mins=None,
        max_runtime_mins=None,
        run_verify=True,
        email_list='',
        dependencies=(),
        reboot_before=None,
        reboot_after=None,
        parse_failed_repair=None,
        hostless=False,
        keyvals=None,
        drone_set=None,
        parent_job_id=None,
        test_retry=0,
        run_reset=True,
        require_ssp=None):
    #pylint: disable-msg=C0111
    """
    Common code between creating "standard" jobs and creating parameterized jobs

    @param hostless: if True the job may not name any hosts and must use a
            server-side control file.
    @raises model_logic.ValidationError: on inconsistent host/hostless
            arguments or an unknown metahost label.
    @returns the id of the created job (via create_new_job).
    """
    # input validation
    host_args_passed = any((hosts, meta_hosts, one_time_hosts))
    if hostless:
        if host_args_passed:
            raise model_logic.ValidationError({
                    'hostless': 'Hostless jobs cannot include any hosts!'})
        if control_type != control_data.CONTROL_TYPE_NAMES.SERVER:
            raise model_logic.ValidationError({
                    'control_type': 'Hostless jobs cannot use client-side '
                                    'control files'})
    elif not host_args_passed:
        raise model_logic.ValidationError({
            'arguments' : "For host jobs, you must pass at least one of"
                          " 'hosts', 'meta_hosts', 'one_time_hosts'."
            })
    label_objects = list(models.Label.objects.filter(name__in=meta_hosts))

    # convert hostnames & meta hosts to host/label objects
    host_objects = models.Host.smart_get_bulk(hosts)
    _validate_host_job_sharding(host_objects)
    for host in one_time_hosts:
        this_host = models.Host.create_one_time_host(host)
        host_objects.append(this_host)

    metahost_objects = []
    meta_host_labels_by_name = {label.name: label for label in label_objects}
    for label_name in meta_hosts:
        if label_name in meta_host_labels_by_name:
            metahost_objects.append(meta_host_labels_by_name[label_name])
        else:
            raise model_logic.ValidationError(
                {'meta_hosts' : 'Label "%s" not found' % label_name})

    # Bundle everything into the options dict consumed by create_new_job().
    options = dict(name=name,
                   priority=priority,
                   control_file=control_file,
                   control_type=control_type,
                   is_template=is_template,
                   timeout=timeout,
                   timeout_mins=timeout_mins,
                   max_runtime_mins=max_runtime_mins,
                   synch_count=synch_count,
                   run_verify=run_verify,
                   email_list=email_list,
                   dependencies=dependencies,
                   reboot_before=reboot_before,
                   reboot_after=reboot_after,
                   parse_failed_repair=parse_failed_repair,
                   keyvals=keyvals,
                   drone_set=drone_set,
                   parent_job_id=parent_job_id,
                   test_retry=test_retry,
                   run_reset=run_reset,
                   require_ssp=require_ssp)

    return create_new_job(owner=models.User.current_user().login,
                          options=options,
                          host_objects=host_objects,
                          metahost_objects=metahost_objects)
823
824
def _validate_host_job_sharding(host_objects):
    """Check that the hosts obey job sharding rules.

    A job must not mix hosts living on a shard with hosts on the master,
    nor span more than one shard.  Shards themselves are exempt from the
    check.

    @param host_objects: Host model objects requested for the job.

    @raises ValueError: If the requested hosts violate the sharding rules
            while this code runs on the master.
    """
    if not (server_utils.is_shard()
            or _allowed_hosts_for_master_job(host_objects)):
        shard_host_map = bucket_hosts_by_shard(host_objects)
        # Fixed typo in the user-facing message: "seperate" -> "separate".
        raise ValueError(
                'The following hosts are on shard(s), please create '
                'separate jobs for hosts on each shard: %s ' %
                shard_host_map)
834
835
def _allowed_hosts_for_master_job(host_objects):
    """Check that the hosts are allowed for a job on master."""
    # Jobs rejected on the master:
    #   - jobs spanning more than one shard;
    #   - jobs touching exactly one shard where that shard's bucket does
    #     not account for every requested host (i.e. the job would span
    #     the shard and the master).
    shard_host_map = bucket_hosts_by_shard(host_objects)
    if not shard_host_map:
        # No host lives on any shard: a pure master job is fine.
        return True
    if len(shard_host_map) > 1:
        return False
    single_shard_hosts = list(shard_host_map.values())[0]
    assert len(single_shard_hosts) <= len(host_objects)
    # Allowed only when the lone shard owns every requested host.
    return len(single_shard_hosts) == len(host_objects)
853
854
def encode_ascii(control_file):
    """Force a control file to only contain ascii characters.

    @param control_file: Control file to encode.

    @returns the control file in an ascii encoding.

    @raises error.ControlFileMalformed: if encoding fails.
    """
    try:
        return control_file.encode('ascii')
    except UnicodeError as e:
        # UnicodeError is the common parent of UnicodeDecodeError (raised
        # when a byte string must first be decoded) and UnicodeEncodeError
        # (raised when a unicode string holds non-ascii characters).  The
        # previous code caught only UnicodeDecodeError, so the encode
        # failure case escaped as an unhandled exception.
        raise error.ControlFileMalformed(str(e))
868
869
def get_wmatrix_url():
    """Read the wmatrix url from the AUTOTEST_WEB section of the config.

    @returns the configured wmatrix url, or an empty string when unset.
    """
    config = global_config.global_config
    return config.get_config_value(
            'AUTOTEST_WEB', 'wmatrix_url', default='')
878
879
def inject_times_to_filter(start_time_key=None, end_time_key=None,
                         start_time_value=None, end_time_value=None,
                         **filter_data):
    """Inject the key value pairs of start and end time if provided.

    Each (key, value) pair is added to filter_data only when its value is
    truthy; falsy values (None, empty, 0) are silently skipped.

    @param start_time_key: A string represents the filter key of start_time.
    @param end_time_key: A string represents the filter key of end_time.
    @param start_time_value: Start_time value.
    @param end_time_value: End_time value.

    @returns the injected filter_data.
    """
    pairs = ((start_time_key, start_time_value),
             (end_time_key, end_time_value))
    for filter_key, filter_value in pairs:
        if filter_value:
            filter_data[filter_key] = filter_value
    return filter_data
897
898
def inject_times_to_hqe_special_tasks_filters(filter_data_common,
                                              start_time, end_time):
    """Inject start and end time to hqe and special tasks filters.

    @param filter_data_common: Common filter for hqe and special tasks.
    @param start_time: Lower bound injected as started_on__gte /
            time_started__gte; skipped when falsy.
    @param end_time: Upper bound injected as started_on__lte /
            time_started__lte; skipped when falsy.

    @returns a pair of (hqe filter, special tasks filter) dicts.
    """
    # The two filters use different column names for the same time range,
    # so build each from its own copy of the common filter data.
    filter_data_special_tasks = filter_data_common.copy()
    return (inject_times_to_filter('started_on__gte', 'started_on__lte',
                                   start_time, end_time, **filter_data_common),
           inject_times_to_filter('time_started__gte', 'time_started__lte',
                                  start_time, end_time,
                                  **filter_data_special_tasks))
915
916
def retrieve_shard(shard_hostname):
    """Look up the shard with the given hostname in the database.

    @param shard_hostname: Hostname of the shard to retrieve.

    @raises models.Shard.DoesNotExist: If no shard matches the hostname.

    @returns: The matching Shard object.
    """
    return models.Shard.smart_get(shard_hostname)
928
929
def find_records_for_shard(shard, known_job_ids, known_host_ids):
    """Find records that should be sent to a shard.

    @param shard: Shard to find records for.
    @param known_job_ids: List of ids of jobs the shard already has.
    @param known_host_ids: List of ids of hosts the shard already has.

    @returns: Tuple of lists:
              (hosts, jobs, suite_job_keyvals, invalid_host_ids)
    """
    # Assign any hosts/jobs not yet known to the shard.
    hosts, invalid_host_ids = models.Host.assign_to_shard(shard,
                                                          known_host_ids)
    jobs = models.Job.assign_to_shard(shard, known_job_ids)
    # Ship the keyvals of the parent (suite) jobs along with the jobs.
    parent_ids = [job.parent_job_id for job in jobs]
    keyvals = models.JobKeyval.objects.filter(job_id__in=parent_ids)
    return hosts, jobs, keyvals, invalid_host_ids
947
948
def _persist_records_with_type_sent_from_shard(
    shard, records, record_type, *args, **kwargs):
    """
    Handle records of a specified type that were sent to the shard master.

    @param shard: The shard the records were sent from.
    @param records: The records sent in their serialized format.
    @param record_type: Type of the objects represented by records.
    @param args: Additional arguments that will be passed on to the sanity
                 checks.
    @param kwargs: Additional arguments that will be passed on to the sanity
                  checks.

    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.

    @returns: List of primary keys of the processed records.
    """
    persisted_pks = []
    for serialized in records:
        record_pk = serialized['id']
        try:
            existing = record_type.objects.get(pk=record_pk)
        except record_type.DoesNotExist:
            raise error.UnallowedRecordsSentToMaster(
                'Object with pk %s of type %s does not exist on master.' % (
                    record_pk, record_type))

        try:
            existing.sanity_check_update_from_shard(
                shard, serialized, *args, **kwargs)
        except error.IgnorableUnallowedRecordsSentToMaster:
            # An illegal record change of a non-fatal variety was
            # attempted; silently skip this record.
            continue
        existing.update_from_serialized(serialized)
        persisted_pks.append(record_pk)

    return persisted_pks
988
989
def persist_records_sent_from_shard(shard, jobs, hqes):
    """
    Sanity checking then saving serialized records sent to master from shard.

    During heartbeats shards upload jobs and host queue entries. This performs
    some sanity checks on these and then updates the existing records for those
    entries with the updated ones from the heartbeat.

    The sanity checks include:
    - Checking if the objects sent already exist on the master.
    - Checking if the objects sent were assigned to this shard.
    - Host queue entries must be sent together with their jobs.

    @param shard: The shard the records were sent from.
    @param jobs: The jobs the shard sent.
    @param hqes: The host queue entries the shard sent.

    @raises error.UnallowedRecordsSentToMaster if any of the sanity checks fail.
    """
    # Jobs are persisted first so the hqe sanity check can verify that each
    # hqe arrived together with its job.
    job_ids_persisted = _persist_records_with_type_sent_from_shard(
            shard, jobs, models.Job)
    _persist_records_with_type_sent_from_shard(
            shard, hqes, models.HostQueueEntry,
            job_ids_sent=job_ids_persisted)
1014
1015
def forward_single_host_rpc_to_shard(func):
    """This decorator forwards rpc calls that modify a host to a shard.

    If a host is assigned to a shard, rpcs that change its attributes should be
    forwarded to the shard.

    This assumes the first argument of the function represents a host id.

    @param func: The function to decorate

    @returns: The function to replace func with.
    """
    # @wraps keeps the rpc's public name/docstring intact, matching the
    # route_rpc_to_master decorator in this file.
    @wraps(func)
    def replacement(**kwargs):
        # Only keyword arguments can be accepted here, as we need the argument
        # names to send the rpc. serviceHandler always provides arguments with
        # their keywords, so this is not a problem.

        # A host record (identified by kwargs['id']) can be deleted in
        # func(). Therefore, we should save the data that can be needed later
        # before func() is called.
        shard_hostname = None
        host = models.Host.smart_get(kwargs['id'])
        if host and host.shard:
            shard_hostname = host.shard.rpc_hostname()
        ret = func(**kwargs)
        if shard_hostname and not server_utils.is_shard():
            # func.__name__ works on both python 2 and 3, unlike the
            # py2-only func.func_name, and matches moblab_only's usage.
            run_rpc_on_multiple_hostnames(func.__name__,
                                          [shard_hostname],
                                          **kwargs)
        return ret

    return replacement
1048
1049
def fanout_rpc(host_objs, rpc_name, include_hostnames=True, **kwargs):
    """Fanout the given rpc to shards of given hosts.

    @param host_objs: Host objects for the rpc.
    @param rpc_name: The name of the rpc.
    @param include_hostnames: If True, include the hostnames in the kwargs.
        Hostnames are not always necessary, this function is designed to
        send rpcs to the shard a host is on, the rpcs themselves could be
        related to labels, acls etc.
    @param kwargs: The kwargs for the rpc.

    @raises error.RPCException: Wrapping whatever a shard rpc raised, with
        the original traceback preserved.
    """
    # Figure out which hosts are on which shards.
    shard_host_map = bucket_hosts_by_shard(
            host_objs, rpc_hostnames=True)

    # Execute the rpc against the appropriate shards.
    for shard, hostnames in shard_host_map.iteritems():
        if include_hostnames:
            # Restrict the rpc to the hosts that live on this shard.
            kwargs['hosts'] = hostnames
        try:
            run_rpc_on_multiple_hostnames(rpc_name, [shard], **kwargs)
        except:
            # Deliberately broad: any failure is rewrapped to record which
            # rpc and which shard failed.
            ei = sys.exc_info()
            new_exc = error.RPCException('RPC %s failed on shard %s due to '
                    '%s: %s' % (rpc_name, shard, ei[0].__name__, ei[1]))
            # Python 2 three-argument raise: re-raise the wrapped exception
            # while keeping the original traceback (ei[2]).
            raise new_exc.__class__, new_exc, ei[2]
1076
1077
def run_rpc_on_multiple_hostnames(rpc_call, shard_hostnames, **kwargs):
    """Runs an rpc to multiple AFEs

    This is i.e. used to propagate changes made to hosts after they are assigned
    to a shard.

    @param rpc_call: Name of the rpc endpoint to call.
    @param shard_hostnames: List of hostnames to run the rpcs on.
    @param **kwargs: Keyword arguments to pass in the rpcs.
    """
    # This must only ever run on the master, never on a shard.
    assert not server_utils.is_shard()
    for hostname in shard_hostnames:
        shard_afe = frontend_wrappers.RetryingAFE(
                server=hostname, user=thread_local.get_user())
        shard_afe.run(rpc_call, **kwargs)
1094
1095
def get_label(name):
    """Gets a label object using a given name.

    Unlike models.Label.smart_get, this never raises for a missing label:
    models.Label.DoesNotExist is caught internally.

    @param name: Label name.
    @return: a label object matching the given name, or None if no such
             label exists.
    """
    try:
        label = models.Label.smart_get(name)
    except models.Label.DoesNotExist:
        return None
    return label
1109
1110
1111# TODO: hide the following rpcs under is_moblab
def moblab_only(func):
    """Ensure moblab specific functions only run on Moblab devices.

    @param func: The RPC function to guard.

    @returns: A wrapper that raises error.RPCException when invoked on a
              non-Moblab system and otherwise calls through to func.
    """
    @wraps(func)
    def verify(*args, **kwargs):
        if not server_utils.is_moblab():
            # The original passed func.__name__ as a second positional
            # exception argument instead of interpolating it; use
            # %-formatting so the message actually names the rpc.
            raise error.RPCException('RPC: %s can only run on Moblab Systems!'
                                     % func.__name__)
        return func(*args, **kwargs)
    return verify
1120
1121
def route_rpc_to_master(func):
    """Route RPC to master AFE.

    When a shard receives an RPC decorated by this, the RPC is just
    forwarded to the master.
    When the master gets the RPC, the RPC function is executed.

    @param func: An RPC function to decorate

    @returns: A function replacing the RPC func.
    """
    # Positional varargs cannot be mapped to keyword arguments, which the
    # forwarding below requires, so reject them at decoration time.
    if inspect.getargspec(func).varargs is not None:
        raise Exception('RPC function must not have *args.')

    @wraps(func)
    def replacement(*args, **kwargs):
        """Forward to the master when running on a shard; else run locally.

        An RPC may be invoked with positional arguments (for example
        rpc_interface.create_job_page_handler() calls create_job() using
        both positional and keyword arguments), but
        frontend.RpcClient.run() only accepts keyword arguments, so every
        call is first normalized to keywords.
        """
        kwargs = _convert_to_kwargs_only(func, args, kwargs)
        if not server_utils.is_shard():
            return func(**kwargs)
        master_afe = frontend_wrappers.RetryingAFE(
                server=server_utils.get_global_afe_hostname(),
                user=thread_local.get_user())
        return master_afe.run(func.func_name, **kwargs)

    return replacement
1157
1158
def _convert_to_kwargs_only(func, args, kwargs):
    """Convert a function call's arguments to a kwargs dict.

    This is best illustrated with an example.  Given:

    def foo(a, b, **kwargs):
        pass

    the call foo(1, 2, c=3) corresponds to

    _convert_to_kwargs_only(foo, (1, 2), {'c': 3})
        == {'a': 1, 'b': 2, 'c': 3}

    so that foo(**result) is equivalent to the original call.

    @param func: function whose signature to use
    @param args: positional arguments of call
    @param kwargs: keyword arguments of call

    @returns: kwargs dict
    """
    argspec = inspect.getargspec(func)
    # callargs looks like {'a': 1, 'b': 2, 'kwargs': {'c': 3}}
    callargs = inspect.getcallargs(func, *args, **kwargs)
    if argspec.keywords is None:
        # func has no **catch-all; every argument already has a name.
        kwargs = {}
    else:
        # Pull the catch-all dict out so its contents can be merged with
        # the named arguments into one flat dict.
        kwargs = callargs.pop(argspec.keywords)
    kwargs.update(callargs)
    return kwargs
1185
1186
def get_sample_dut(board, pool):
    """Get a dut with the given board and pool.

    This method is used to help to locate a dut with the given board and pool.
    The dut then can be used to identify a devserver in the same subnet.

    @param board: Name of the board.
    @param pool: Name of the pool.

    @return: Name of a dut with the given board and pool.
    """
    # Only useful when a local devserver is preferred and both labels are
    # provided.
    if not (dev_server.PREFER_LOCAL_DEVSERVER and pool and board):
        return None
    matching_hosts = list(get_host_query(
        multiple_labels=('pool:%s' % pool, 'board:%s' % board),
        exclude_only_if_needed_labels=False,
        valid_only=True,
        filter_data={},
    ))
    if matching_hosts:
        return matching_hosts[0].hostname
    return None
1210