• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Extensions to Django's model logic.
3"""
4
5import django.core.exceptions
6from django.db import backend
7from django.db import connection
8from django.db import connections
9from django.db import models as dbmodels
10from django.db import transaction
11from django.db.models.sql import query
12import django.db.models.sql.where
13from django.utils import datastructures
14from autotest_lib.client.common_lib.cros.graphite import autotest_stats
15from autotest_lib.frontend.afe import rdb_model_extensions
16
17
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Raised when adding or updating an object fails data validation.

    The wrapped value is a dict mapping each offending field name to the
    error string describing what is wrong with it.
    """
23
def _quote_name(name):
    """Quote an SQL identifier with the default connection's quoting rules.

    Convenience wrapper around connection.ops.quote_name().
    """
    ops = connection.ops
    return ops.quote_name(name)
27
28
class LeasedHostManager(dbmodels.Manager):
    """Query manager whose default query set only contains hosts that are
    neither leased nor locked (leased=0, locked=0).
    """
    def get_query_set(self):
        base_queryset = super(LeasedHostManager, self).get_query_set()
        return base_queryset.filter(leased=0, locked=0)
35
36
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    On top of the standard Manager API this offers helpers to add custom SQL
    joins and raw WHERE fragments to query sets, to run custom SELECTs, and
    to bulk-populate lists of related objects.
    """

    class CustomQuery(query.Query):
        """Query subclass that carries a list of extra custom joins.

        Each entry of _custom_joins is a dict created by add_custom_join().
        Rendering these dicts into SQL is not done in this class.
        """
        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """Clone this query, giving the clone its own copy of the joins.

            Extra keyword arguments are forwarded to Query.clone() rather
            than being discarded.
            """
            obj = super(ExtendedManager.CustomQuery, self).clone(klass,
                                                                   **kwargs)
            # Always inherit this query's joins (even if kwargs carried a
            # _custom_joins value) so that repeated add_join() calls on the
            # same query set accumulate rather than reset.  Copy the list so
            # the two queries can diverge independently.
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Combine with rhs (see Query.combine) and merge its joins."""
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """Record a custom join on this query.

            @param table: table to join to.
            @param condition: SQL for the ON clause of the join.
            @param join_type: join type, e.g. Query.INNER or Query.LOUTER.
            @param condition_values: values for placeholders in condition.
            @param alias: alias for the joined table; defaults to table.
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        @classmethod
        def convert_query(cls, query_set):
            """
            Convert the query set's "query" attribute to a CustomQuery.

            @param query_set: query set to convert; it is not modified.
            @returns a copy of query_set whose query is a CustomQuery.  If the
                    query already was a CustomQuery, its custom joins are
                    carried over by clone().
            """
            # Make a copy of the query set
            query_set = query_set.all()
            query_set.query = query_set.query.clone(
                    klass=cls,
                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            """Return the raw (clause, values) pair, ignoring qn/connection."""
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            """No-op: the raw SQL clause cannot be relabeled."""
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from.
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
        LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
        instead of an INNER JOIN regardless of other options
        @returns a new query set (with a CustomQuery) containing the join.
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            join_from_key = self.model._meta.pk.name
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # With a LEFT JOIN, rows without a match have NULL join keys.
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        # NOTE(review): relabel_aliases() is called on join_to_query's own
        # WHERE node, not a copy, so the caller's query set is presumably
        # relabeled as well.
        rhs_where = join_to_query.query.where
        rhs_where.relabel_aliases({rhs_table: alias})
        # NOTE(review): compiler.as_sql() compiles the whole join_to_query
        # query; confirm this is the SQL intended for the ON clause.
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        initial_clause, values = compiler.as_sql()
        all_clauses = (initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                    ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                manager's model, False if it is declared on the related model.
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        @raises AssertionError (unless assertions are disabled) if
                join_to_query does not match exactly one related object.
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join.  we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
                                               _quote_name(m2m_rhs_column),
                                               rhs_id)
        info['values'] = ()
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship.  join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @returns the new query set produced by add_join().
        @raises ValueError if the related model has no relation to this model.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        # The relationship types are identity sentinels (object() instances).
        if relationship_type is self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type is self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        elif relationship_type is self.M2M_ON_THIS_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def key_on_joined_table(self, join_to_query):
        """Get a non-null column on the table joined for the given query.

        This analyzes the join that would be produced if join_to_query were
        passed to join_custom_field.
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)
        if relationship_type is self.MANY_TO_ONE:
            return join_to_query.model._meta.pk.column
        return field.m2m_column_name() # any column on the M2M table will do


    def add_where(self, query_set, where, values=()):
        """Return a copy of query_set with the raw SQL `where` (and its
        `values`) ANDed into its WHERE clause."""
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return the quoted `table.field` reference."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """Return the quoted reference to key_field (default: the primary key
        column) on this model's table."""
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double every '%' in sql (applied before user SQL is passed on as
        an extra WHERE clause)."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """Execute a custom select query.

        @param query_set: query set as returned by query_objects.
        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.

        @returns: Result of the query as returned by cursor.fetchall().
        """
        compiler = query_set.query.get_compiler(using=query_set.db)
        sql, params = compiler.as_sql()
        # Keep everything from the first ' FROM' on and replace the column
        # list in front of it with the requested selects.
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        # Choose the connection that's responsible for this type of object
        cursor = connections[query_set.db].cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Truthy if field is a relation whose target is model_class."""
        return field.rel and field.rel.to is model_class


    # Opaque sentinels returned by determine_relationship().
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if no relationship is found.
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type is self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type is self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type is self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field, related_model):
        """
        @param base_objects_by_id dict of self.model objects keyed by ID; its
        keys are the IDs to include (must be non-empty, or the IN () clause
        is invalid SQL)
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
        self.model
        @param pivot_to_field a field name on pivot_table referencing the
        related model.
        @param related_model the related model

        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_
                                    in base_objects_by_id.iterkeys()))

        # Choose the connection that's responsible for this type of object
        # The databases for related_model and the current model will always
        # be the same, related_model is just easier to obtain here because
        # self is only a ExtendedManager, not the object.
        cursor = connections[related_model.objects.db].cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field,
                                           related_model)

        # Fetch every distinct related object with a single query.
        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
        object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                  for base_object in base_objects)
        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                  related_model)

        for base_object in base_objects:
            setattr(base_object, related_list_name, [])

        for base_object, related_object in pivot_iterator:
            getattr(base_object, related_list_name).append(related_object)
480
481
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet for models with an "invalid" bit whose delete() removes rows one
    at a time, through each instance's own delete() method, instead of
    issuing a bulk SQL delete.
    """
    def delete(self):
        for instance in self:
            instance.delete()
489
490
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit.

    Its query sets are ModelWithInvalidQuerySet instances, so a bulk delete()
    goes through each object's own delete() method.
    """
    def get_query_set(self):
        # Pass the manager's database alias through, as Django's default
        # Manager.get_query_set() does; without it, a database chosen via
        # db_manager() would be silently ignored.
        return ModelWithInvalidQuerySet(self.model, using=self._db)
497
498
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager that hides invalidated objects: its query sets only contain rows
    with invalid=False.
    """
    def get_query_set(self):
        return super(ValidObjectsManager, self).get_query_set().filter(
                invalid=False)
506
507
508class ModelExtensions(rdb_model_extensions.ModelValidators):
509    """\
510    Mixin with convenience functions for models, built on top of
511    the model validators in rdb_model_extensions.
512    """
513    # TODO: at least some of these functions really belong in a custom
514    # Manager class
515
516
517    SERIALIZATION_LINKS_TO_FOLLOW = set()
518    """
519    To be able to send jobs and hosts to shards, it's necessary to find their
520    dependencies.
521    The most generic approach for this would be to traverse all relationships
522    to other objects recursively. This would list all objects that are related
523    in any way.
524    But this approach finds too many objects: If a host should be transferred,
525    all it's relationships would be traversed. This would find an acl group.
526    If then the acl group's relationships are traversed, the relationship
527    would be followed backwards and many other hosts would be found.
528
529    This mapping tells that algorithm which relations to follow explicitly.
530    """
531
532
533    SERIALIZATION_LINKS_TO_KEEP = set()
534    """This set stores foreign keys which we don't want to follow, but
535    still want to include in the serialized dictionary. For
536    example, we follow the relationship `Host.hostattribute_set`,
537    but we do not want to follow `HostAttributes.host_id` back to
538    to Host, which would otherwise lead to a circle. However, we still
539    like to serialize HostAttribute.`host_id`."""
540
541    SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
542    """
543    On deserializion, if the object to persist already exists, local fields
544    will only be updated, if their name is in this set.
545    """
546
547
548    @classmethod
549    def convert_human_readable_values(cls, data, to_human_readable=False):
550        """\
551        Performs conversions on user-supplied field data, to make it
552        easier for users to pass human-readable data.
553
554        For all fields that have choice sets, convert their values
555        from human-readable strings to enum values, if necessary.  This
556        allows users to pass strings instead of the corresponding
557        integer values.
558
559        For all foreign key fields, call smart_get with the supplied
560        data.  This allows the user to pass either an ID value or
561        the name of the object as a string.
562
563        If to_human_readable=True, perform the inverse - i.e. convert
564        numeric values to human readable values.
565
566        This method modifies data in-place.
567        """
568        field_dict = cls.get_field_dict()
569        for field_name in data:
570            if field_name not in field_dict or data[field_name] is None:
571                continue
572            field_obj = field_dict[field_name]
573            # convert enum values
574            if field_obj.choices:
575                for choice_data in field_obj.choices:
576                    # choice_data is (value, name)
577                    if to_human_readable:
578                        from_val, to_val = choice_data
579                    else:
580                        to_val, from_val = choice_data
581                    if from_val == data[field_name]:
582                        data[field_name] = to_val
583                        break
584            # convert foreign key values
585            elif field_obj.rel:
586                dest_obj = field_obj.rel.to.smart_get(data[field_name],
587                                                      valid_only=False)
588                if to_human_readable:
589                    # parameterized_jobs do not have a name_field
590                    if (field_name != 'parameterized_job' and
591                        dest_obj.name_field is not None):
592                        data[field_name] = getattr(dest_obj,
593                                                   dest_obj.name_field)
594                else:
595                    data[field_name] = dest_obj
596
597
598
599
600    def _validate_unique(self):
601        """\
602        Validate that unique fields are unique.  Django manipulators do
603        this too, but they're a huge pain to use manually.  Trust me.
604        """
605        errors = {}
606        cls = type(self)
607        field_dict = self.get_field_dict()
608        manager = cls.get_valid_manager()
609        for field_name, field_obj in field_dict.iteritems():
610            if not field_obj.unique:
611                continue
612
613            value = getattr(self, field_name)
614            if value is None and field_obj.auto_created:
615                # don't bother checking autoincrement fields about to be
616                # generated
617                continue
618
619            existing_objs = manager.filter(**{field_name : value})
620            num_existing = existing_objs.count()
621
622            if num_existing == 0:
623                continue
624            if num_existing == 1 and existing_objs[0].id == self.id:
625                continue
626            errors[field_name] = (
627                'This value must be unique (%s)' % (value))
628        return errors
629
630
631    def _validate(self):
632        """
633        First coerces all fields on this instance to their proper Python types.
634        Then runs validation on every field. Returns a dictionary of
635        field_name -> error_list.
636
637        Based on validate() from django.db.models.Model in Django 0.96, which
638        was removed in Django 1.0. It should reappear in a later version. See:
639            http://code.djangoproject.com/ticket/6845
640        """
641        error_dict = {}
642        for f in self._meta.fields:
643            try:
644                python_value = f.to_python(
645                    getattr(self, f.attname, f.get_default()))
646            except django.core.exceptions.ValidationError, e:
647                error_dict[f.name] = str(e)
648                continue
649
650            if not f.blank and not python_value:
651                error_dict[f.name] = 'This field is required.'
652                continue
653
654            setattr(self, f.attname, python_value)
655
656        return error_dict
657
658
659    def do_validate(self):
660        errors = self._validate()
661        unique_errors = self._validate_unique()
662        for field_name, error in unique_errors.iteritems():
663            errors.setdefault(field_name, error)
664        if errors:
665            raise ValidationError(errors)
666
667
668    # actually (externally) useful methods follow
669
670    @classmethod
671    def add_object(cls, data={}, **kwargs):
672        """\
673        Returns a new object created with the given data (a dictionary
674        mapping field names to values). Merges any extra keyword args
675        into data.
676        """
677        data = dict(data)
678        data.update(kwargs)
679        data = cls.prepare_data_args(data)
680        cls.convert_human_readable_values(data)
681        data = cls.provide_default_values(data)
682
683        obj = cls(**data)
684        obj.do_validate()
685        obj.save()
686        return obj
687
688
689    def update_object(self, data={}, **kwargs):
690        """\
691        Updates the object with the given data (a dictionary mapping
692        field names to values).  Merges any extra keyword args into
693        data.
694        """
695        data = dict(data)
696        data.update(kwargs)
697        data = self.prepare_data_args(data)
698        self.convert_human_readable_values(data)
699        for field_name, value in data.iteritems():
700            setattr(self, field_name, value)
701        self.do_validate()
702        self.save()
703
704
705    # see query_objects()
706    _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
707                            'extra_args', 'extra_where', 'no_distinct')
708
709
710    @classmethod
711    def _extract_special_params(cls, filter_data):
712        """
713        @returns a tuple of dicts (special_params, regular_filters), where
714        special_params contains the parameters we handle specially and
715        regular_filters is the remaining data to be handled by Django.
716        """
717        regular_filters = dict(filter_data)
718        special_params = {}
719        for key in cls._SPECIAL_FILTER_KEYS:
720            if key in regular_filters:
721                special_params[key] = regular_filters.pop(key)
722        return special_params, regular_filters
723
724
725    @classmethod
726    def apply_presentation(cls, query, filter_data):
727        """
728        Apply presentation parameters -- sorting and paging -- to the given
729        query.
730        @returns new query with presentation applied
731        """
732        special_params, _ = cls._extract_special_params(filter_data)
733        sort_by = special_params.get('sort_by', None)
734        if sort_by:
735            assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
736            query = query.extra(order_by=sort_by)
737
738        query_start = special_params.get('query_start', None)
739        query_limit = special_params.get('query_limit', None)
740        if query_start is not None:
741            if query_limit is None:
742                raise ValueError('Cannot pass query_start without query_limit')
743            # query_limit is passed as a page size
744            query_limit += query_start
745        return query[query_start:query_limit]
746
747
748    @classmethod
749    def query_objects(cls, filter_data, valid_only=True, initial_query=None,
750                      apply_presentation=True):
751        """\
752        Returns a QuerySet object for querying the given model_class
753        with the given filter_data.  Optional special arguments in
754        filter_data include:
755        -query_start: index of first return to return
756        -query_limit: maximum number of results to return
757        -sort_by: list of fields to sort on.  prefixing a '-' onto a
758         field name changes the sort to descending order.
759        -extra_args: keyword args to pass to query.extra() (see Django
760         DB layer documentation)
761        -extra_where: extra WHERE clause to append
762        -no_distinct: if True, a DISTINCT will not be added to the SELECT
763        """
764        special_params, regular_filters = cls._extract_special_params(
765                filter_data)
766
767        if initial_query is None:
768            if valid_only:
769                initial_query = cls.get_valid_manager()
770            else:
771                initial_query = cls.objects
772
773        query = initial_query.filter(**regular_filters)
774
775        use_distinct = not special_params.get('no_distinct', False)
776        if use_distinct:
777            query = query.distinct()
778
779        extra_args = special_params.get('extra_args', {})
780        extra_where = special_params.get('extra_where', None)
781        if extra_where:
782            # escape %'s
783            extra_where = cls.objects.escape_user_sql(extra_where)
784            extra_args.setdefault('where', []).append(extra_where)
785        if extra_args:
786            query = query.extra(**extra_args)
787            # TODO: Use readonly connection for these queries.
788            # This has been disabled, because it's not used anyway, as the
789            # configured readonly user is the same as the real user anyway.
790
791        if apply_presentation:
792            query = cls.apply_presentation(query, filter_data)
793
794        return query
795
796
797    @classmethod
798    def query_count(cls, filter_data, initial_query=None):
799        """\
800        Like query_objects, but retreive only the count of results.
801        """
802        filter_data.pop('query_start', None)
803        filter_data.pop('query_limit', None)
804        query = cls.query_objects(filter_data, initial_query=initial_query)
805        return query.count()
806
807
808    @classmethod
809    def clean_object_dicts(cls, field_dicts):
810        """\
811        Take a list of dicts corresponding to object (as returned by
812        query.values()) and clean the data to be more suitable for
813        returning to the user.
814        """
815        for field_dict in field_dicts:
816            cls.clean_foreign_keys(field_dict)
817            cls._convert_booleans(field_dict)
818            cls.convert_human_readable_values(field_dict,
819                                              to_human_readable=True)
820
821
822    @classmethod
823    def list_objects(cls, filter_data, initial_query=None):
824        """\
825        Like query_objects, but return a list of dictionaries.
826        """
827        query = cls.query_objects(filter_data, initial_query=initial_query)
828        extra_fields = query.query.extra_select.keys()
829        field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
830                       for model_object in query]
831        return field_dicts
832
833
834    @classmethod
835    def smart_get(cls, id_or_name, valid_only=True):
836        """\
837        smart_get(integer) -> get object by ID
838        smart_get(string) -> get object by name_field
839        """
840        if valid_only:
841            manager = cls.get_valid_manager()
842        else:
843            manager = cls.objects
844
845        if isinstance(id_or_name, (int, long)):
846            return manager.get(pk=id_or_name)
847        if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
848            return manager.get(**{cls.name_field : id_or_name})
849        raise ValueError(
850            'Invalid positional argument: %s (%s)' % (id_or_name,
851                                                      type(id_or_name)))
852
853
854    @classmethod
855    def smart_get_bulk(cls, id_or_name_list):
856        invalid_inputs = []
857        result_objects = []
858        for id_or_name in id_or_name_list:
859            try:
860                result_objects.append(cls.smart_get(id_or_name))
861            except cls.DoesNotExist:
862                invalid_inputs.append(id_or_name)
863        if invalid_inputs:
864            raise cls.DoesNotExist('The following %ss do not exist: %s'
865                                   % (cls.__name__.lower(),
866                                      ', '.join(invalid_inputs)))
867        return result_objects
868
869
870    def get_object_dict(self, extra_fields=None):
871        """\
872        Return a dictionary mapping fields to this object's values.  @param
873        extra_fields: list of extra attribute names to include, in addition to
874        the fields defined on this object.
875        """
876        fields = self.get_field_dict().keys()
877        if extra_fields:
878            fields += extra_fields
879        object_dict = dict((field_name, getattr(self, field_name))
880                           for field_name in fields)
881        self.clean_object_dicts([object_dict])
882        self._postprocess_object_dict(object_dict)
883        return object_dict
884
885
886    def _postprocess_object_dict(self, object_dict):
887        """For subclasses to override."""
888        pass
889
890
891    @classmethod
892    def get_valid_manager(cls):
893        return cls.objects
894
895
896    def _record_attributes(self, attributes):
897        """
898        See on_attribute_changed.
899        """
900        assert not isinstance(attributes, basestring)
901        self._recorded_attributes = dict((attribute, getattr(self, attribute))
902                                         for attribute in attributes)
903
904
905    def _check_for_updated_attributes(self):
906        """
907        See on_attribute_changed.
908        """
909        for attribute, original_value in self._recorded_attributes.iteritems():
910            new_value = getattr(self, attribute)
911            if original_value != new_value:
912                self.on_attribute_changed(attribute, original_value)
913        self._record_attributes(self._recorded_attributes.keys())
914
915
916    def on_attribute_changed(self, attribute, old_value):
917        """
918        Called whenever an attribute is updated.  To be overridden.
919
920        To use this method, you must:
921        * call _record_attributes() from __init__() (after making the super
922        call) with a list of attributes for which you want to be notified upon
923        change.
924        * call _check_for_updated_attributes() from save().
925        """
926        pass
927
928
929    def serialize(self, include_dependencies=True):
930        """Serializes the object with dependencies.
931
932        The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies
933        this function will serialize with the object.
934
935        @param include_dependencies: Whether or not to follow relations to
936                                     objects this object depends on.
937                                     This parameter is used when uploading
938                                     jobs from a shard to the master, as the
939                                     master already has all the dependent
940                                     objects.
941
942        @returns: Dictionary representation of the object.
943        """
944        serialized = {}
945        timer = autotest_stats.Timer('serialize_latency.%s' % (
946                type(self).__name__))
947        with timer.get_client('local'):
948            for field in self._meta.concrete_model._meta.local_fields:
949                if field.rel is None:
950                    serialized[field.name] = field._get_val_from_obj(self)
951                elif field.name in self.SERIALIZATION_LINKS_TO_KEEP:
952                    # attname will contain "_id" suffix for foreign keys,
953                    # e.g. HostAttribute.host will be serialized as 'host_id'.
954                    # Use it for easy deserialization.
955                    serialized[field.attname] = field._get_val_from_obj(self)
956
957        if include_dependencies:
958            with timer.get_client('related'):
959                for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
960                    serialized[link] = self._serialize_relation(link)
961
962        return serialized
963
964
965    def _serialize_relation(self, link):
966        """Serializes dependent objects given the name of the relation.
967
968        @param link: Name of the relation to take objects from.
969
970        @returns For To-Many relationships a list of the serialized related
971            objects, for To-One relationships the serialized related object.
972        """
973        try:
974            attr = getattr(self, link)
975        except AttributeError:
976            # One-To-One relationships that point to None may raise this
977            return None
978
979        if attr is None:
980            return None
981        if hasattr(attr, 'all'):
982            return [obj.serialize() for obj in attr.all()]
983        return attr.serialize()
984
985
986    @classmethod
987    def _split_local_from_foreign_values(cls, data):
988        """This splits local from foreign values in a serialized object.
989
990        @param data: The serialized object.
991
992        @returns A tuple of two lists, both containing tuples in the form
993                 (link_name, link_value). The first list contains all links
994                 for local fields, the second one contains those for foreign
995                 fields/objects.
996        """
997        links_to_local_values, links_to_related_values = [], []
998        for link, value in data.iteritems():
999            if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
1000                # It's a foreign key
1001                links_to_related_values.append((link, value))
1002            else:
1003                # It's a local attribute or a foreign key
1004                # we don't want to follow.
1005                links_to_local_values.append((link, value))
1006        return links_to_local_values, links_to_related_values
1007
1008
1009    @classmethod
1010    def _filter_update_allowed_fields(cls, data):
1011        """Filters data and returns only files that updates are allowed on.
1012
1013        This is i.e. needed for syncing aborted bits from the master to shards.
1014
1015        Local links are only allowed to be updated, if they are in
1016        SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1017        Overwriting existing values is allowed in order to be able to sync i.e.
1018        the aborted bit from the master to a shard.
1019
1020        The whitelisting mechanism is in place to prevent overwriting local
1021        status: If all fields were overwritten, jobs would be completely be
1022        set back to their original (unstarted) state.
1023
1024        @param data: List with tuples of the form (link_name, link_value), as
1025                     returned by _split_local_from_foreign_values.
1026
1027        @returns List of the same format as data, but only containing data for
1028                 fields that updates are allowed on.
1029        """
1030        return [pair for pair in data
1031                if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1032
1033
1034    @classmethod
1035    def delete_matching_record(cls, **filter_args):
1036        """Delete records matching the filter.
1037
1038        @param filter_args: Arguments for the django filter
1039                used to locate the record to delete.
1040        """
1041        try:
1042            existing_record = cls.objects.get(**filter_args)
1043        except cls.DoesNotExist:
1044            return
1045        existing_record.delete()
1046
1047
    def _deserialize_local(self, data):
        """Set local attributes from a list of tuples and save the object.

        Does nothing (and does not save) when data is empty.

        @param data: List of tuples like returned by
                     _split_local_from_foreign_values.
        """
        if not data:
            return

        for link, value in data:
            setattr(self, link, value)
        # Overridden save() methods are prone to errors, so don't execute them.
        # This is because:
        # - the overridden methods depend on ACL groups that don't yet exist
        #   and don't handle errors
        # - the overridden methods think this object already exists in the db
        #   because the id is already set
        # NOTE(review): super(type(self), self) resolves relative to the
        # runtime class, so it bypasses at most a save() defined directly on
        # type(self).  save() overrides further up the MRO (for example
        # ModelWithInvalid.save) still run.
        super(type(self), self).save()
1066
1067
    def _deserialize_relations(self, data):
        """Set foreign attributes from a list of tuples and save the object.

        This deserializes the related objects using their own deserialize()
        function and then sets the relation.

        @param data: List of tuples like returned by
                     _split_local_from_foreign_values.
        """
        for link, value in data:
            self._deserialize_relation(link, value)
        # See comment in _deserialize_local: this bypasses at most a save()
        # defined directly on type(self).  Unlike _deserialize_local, this
        # saves even when data is empty.
        super(type(self), self).save()
1081
1082
1083    @classmethod
1084    def get_record(cls, data):
1085        """Retrieve a record with the data in the given input arg.
1086
1087        @param data: A dictionary containing the information to use in a query
1088                for data. If child models have different constraints of
1089                uniqueness they should override this model.
1090
1091        @return: An object with matching data.
1092
1093        @raises DoesNotExist: If a record with the given data doesn't exist.
1094        """
1095        return cls.objects.get(id=data['id'])
1096
1097
1098    @classmethod
1099    def deserialize(cls, data):
1100        """Recursively deserializes and saves an object with it's dependencies.
1101
1102        This takes the result of the serialize method and creates objects
1103        in the database that are just like the original.
1104
1105        If an object of the same type with the same id already exists, it's
1106        local values will be left untouched, unless they are explicitly
1107        whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1108
1109        Deserialize will always recursively propagate to all related objects
1110        present in data though.
1111        I.e. this is necessary to add users to an already existing acl-group.
1112
1113        @param data: Representation of an object and its dependencies, as
1114                     returned by serialize.
1115
1116        @returns: The object represented by data if it didn't exist before,
1117                  otherwise the object that existed before and has the same type
1118                  and id as the one described by data.
1119        """
1120        if data is None:
1121            return None
1122
1123        local, related = cls._split_local_from_foreign_values(data)
1124        try:
1125            instance = cls.get_record(data)
1126            local = cls._filter_update_allowed_fields(local)
1127        except cls.DoesNotExist:
1128            instance = cls()
1129
1130        timer = autotest_stats.Timer('deserialize_latency.%s' % (
1131                type(instance).__name__))
1132        with timer.get_client('local'):
1133            instance._deserialize_local(local)
1134        with timer.get_client('related'):
1135            instance._deserialize_relations(related)
1136
1137        return instance
1138
1139
1140    def sanity_check_update_from_shard(self, shard, updated_serialized,
1141                                       *args, **kwargs):
1142        """Check if an update sent from a shard is legitimate.
1143
1144        @raises error.UnallowedRecordsSentToMaster if an update is not
1145                legitimate.
1146        """
1147        raise NotImplementedError(
1148            'sanity_check_update_from_shard must be implemented by subclass %s '
1149            'for type %s' % type(self))
1150
1151
1152    @transaction.commit_on_success
1153    def update_from_serialized(self, serialized):
1154        """Updates local fields of an existing object from a serialized form.
1155
1156        This is different than the normal deserialize() in the way that it
1157        does update local values, which deserialize doesn't, but doesn't
1158        recursively propagate to related objects, which deserialize() does.
1159
1160        The use case of this function is to update job records on the master
1161        after the jobs have been executed on a slave, as the master is not
1162        interested in updates for users, labels, specialtasks, etc.
1163
1164        @param serialized: Representation of an object and its dependencies, as
1165                           returned by serialize.
1166
1167        @raises ValueError: if serialized contains related objects, i.e. not
1168                            only local fields.
1169        """
1170        local, related = (
1171            self._split_local_from_foreign_values(serialized))
1172        if related:
1173            raise ValueError('Serialized must not contain foreign '
1174                             'objects: %s' % related)
1175
1176        self._deserialize_local(local)
1177
1178
1179    def custom_deserialize_relation(self, link, data):
1180        """Allows overriding the deserialization behaviour by subclasses."""
1181        raise NotImplementedError(
1182            'custom_deserialize_relation must be implemented by subclass %s '
1183            'for relation %s' % (type(self), link))
1184
1185
1186    def _deserialize_relation(self, link, data):
1187        """Deserializes related objects and sets references on this object.
1188
1189        Relations that point to a list of objects are handled automatically.
1190        For many-to-one or one-to-one relations custom_deserialize_relation
1191        must be overridden by the subclass.
1192
1193        Related objects are deserialized using their deserialize() method.
1194        Thereby they and their dependencies are created if they don't exist
1195        and saved to the database.
1196
1197        @param link: Name of the relation.
1198        @param data: Serialized representation of the related object(s).
1199                     This means a list of dictionaries for to-many relations,
1200                     just a dictionary for to-one relations.
1201        """
1202        field = getattr(self, link)
1203
1204        if field and hasattr(field, 'all'):
1205            self._deserialize_2m_relation(link, data, field.model)
1206        else:
1207            self.custom_deserialize_relation(link, data)
1208
1209
1210    def _deserialize_2m_relation(self, link, data, related_class):
1211        """Deserialize related objects for one to-many relationship.
1212
1213        @param link: Name of the relation.
1214        @param data: Serialized representation of the related objects.
1215                     This is a list with of dictionaries.
1216        @param related_class: A class representing a django model, with which
1217                              this class has a one-to-many relationship.
1218        """
1219        relation_set = getattr(self, link)
1220        if related_class == self.get_attribute_model():
1221            # When deserializing a model together with
1222            # its attributes, clear all the exising attributes to ensure
1223            # db consistency. Note 'update' won't be sufficient, as we also
1224            # want to remove any attributes that no longer exist in |data|.
1225            #
1226            # core_filters is a dictionary of filters, defines how
1227            # RelatedMangager would query for the 1-to-many relationship. E.g.
1228            # Host.objects.get(
1229            #     id=20).hostattribute_set.core_filters = {host_id:20}
1230            # We use it to delete objects related to the current object.
1231            related_class.objects.filter(**relation_set.core_filters).delete()
1232        for serialized in data:
1233            relation_set.add(related_class.deserialize(serialized))
1234
1235
1236    @classmethod
1237    def get_attribute_model(cls):
1238        """Return the attribute model.
1239
1240        Subclass with attribute-like model should override this to
1241        return the attribute model class. This method will be
1242        called by _deserialize_2m_relation to determine whether
1243        to clear the one-to-many relations first on deserialization of object.
1244        """
1245        return None
1246
1247
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field, a "name_field" attribute naming the field that identifies an
    object, and a "valid_objects" manager.
    """

    def save(self, *args, **kwargs):
        """Save the object, reusing a previously invalidated row if any.

        On the first save (id is None), if an invalid object with the same
        name exists, resurrect_object() is called with it before saving, so
        that by default this object takes over the old object's id.
        """
        first_time = (self.id is None)
        if first_time:
            # see if this object was previously added and invalidated
            my_name = getattr(self, self.name_field)
            filters = {self.name_field : my_name, 'invalid' : True}
            try:
                old_object = self.__class__.objects.get(**filters)
            except self.DoesNotExist:
                # no existing object
                pass
            else:
                # Outside the try so a DoesNotExist raised while resurrecting
                # is not mistaken for "no previous object".
                self.resurrect_object(old_object)

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is actually
        "undeleting" a previously deleted object.  Can be overridden by
        subclasses to copy data as desired from the deleted entry (but this
        superclass implementation must normally be called).
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """Mark the object invalid instead of deleting its row.

        Saves the object with invalid=True, then calls clean_object().
        Deleting an object that is already invalid fails the assertion.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager that excludes invalidated objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        # NOTE(review): this class derives from object, which defines no
        # _prepare, so calling _prepare() would raise AttributeError.  It
        # looks like a leftover from Django's old manipulator API -- confirm
        # whether anything still uses it.
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
1313
1314
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".

    Subclasses implement _get_attribute_model_and_args(); the other methods
    use it to set, delete, or conditionally set an attribute.
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Return (attribute_model, keyword_args) for the given attribute.

        Subclasses must override this.  attribute_model is a model class and
        keyword_args is a dict of args to pass to attribute_model.objects.get()
        to get an instance of the given attribute on this object.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """Create the attribute if needed and store value in it."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        instance, _created = model.objects.get_or_create(**lookup)
        instance.value = value
        instance.save()


    def delete_attribute(self, attribute):
        """Delete the attribute; do nothing if it does not exist."""
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            instance = model.objects.get(**lookup)
            instance.delete()
        except model.DoesNotExist:
            return


    def set_or_delete_attribute(self, attribute, value):
        """Set the attribute to value, or delete it when value is None."""
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
1353
1354
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class.

    create() is disabled; get_or_create() fills in the_hash from the model's
    _compute_hash() before delegating to the standard implementation.
    """

    def create(self, **kwargs):
        """Always raises; get_or_create() must be used instead."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """Add the_hash, computed from kwargs, then get or create the row."""
        digest = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = digest
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1366
1367
class ModelWithHash(dbmodels.Model):
    """Abstract model whose rows carry a unique hash column.

    Subclasses implement _compute_hash(); rows are obtained through
    objects.get_or_create(), which stores that hash in the_hash.
    """

    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Return the value to store in the_hash for a row built from kwargs.

        Subclasses must override this.
        """
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work.  These models should be instantiated through the
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row.  However, the preferred way to make instances
        of these models is through the get_or_create() method.
        """
        if force_insert:
            # A forced insert only ever adds a row; a duplicate is rejected
            # by the unique constraint on the_hash.
            super(ModelWithHash, self).save(force_insert=force_insert, **kwargs)
            return
        raise Exception('ModelWithHash is immutable')
1400