• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Extensions to Django's model logic.
3"""
4
5import django.core.exceptions
6from django.db import connection
7from django.db import connections
8from django.db import models as dbmodels
9from django.db import transaction
10from django.db.models.sql import query
11import django.db.models.sql.where
12
13from autotest_lib.client.common_lib import error
14from autotest_lib.frontend.afe import rdb_model_extensions
15
16
class ValidationError(django.core.exceptions.ValidationError):
    """Raised when data for adding or updating an object fails validation.

    The exception's value is a dict mapping each offending field name to
    its error string.
    """
22
def _quote_name(name):
    """Return `name` quoted as an SQL identifier for the default connection.

    Thin shorthand for connection.ops.quote_name().
    """
    quote = connection.ops.quote_name
    return quote(name)
26
27
class LeasedHostManager(dbmodels.Manager):
    """Query manager that only yields hosts with leased=0 and locked=0.

    NOTE(review): despite the class name, the filter selects *unleased*,
    unlocked hosts.
    """
    def get_query_set(self):
        base_query_set = super(LeasedHostManager, self).get_query_set()
        return base_query_set.filter(leased=0, locked=0)
34
35
class ExtendedManager(dbmodels.Manager):
    """\
    Extended manager supporting subquery filtering.

    Besides the usual Manager API it offers helpers to add hand-built joins
    and raw WHERE clauses to query sets, to run custom SELECTs, and to
    bulk-populate lists of related objects on model instances.
    """

    class CustomQuery(query.Query):
        """Django Query that additionally records hand-built joins.

        Joins are stored as dicts in self._custom_joins by add_custom_join().
        NOTE(review): the SQL generation that consumes _custom_joins is not
        visible here; presumably a custom compiler/backend reads it.
        """

        def __init__(self, *args, **kwargs):
            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
            self._custom_joins = []


        def clone(self, klass=None, **kwargs):
            """Clone the query, giving the clone its own copy of the joins.

            Keyword arguments (e.g. Django's ``memo``, or attributes to set
            on the clone) are forwarded to Query.clone() instead of being
            silently discarded.
            """
            obj = super(ExtendedManager.CustomQuery, self).clone(klass,
                                                                 **kwargs)
            # Copy the list so joins later added to either query stay
            # independent of the other.
            obj._custom_joins = list(self._custom_joins)
            return obj


        def combine(self, rhs, connector):
            """Combine with rhs (for & / |), keeping rhs's custom joins."""
            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
            if hasattr(rhs, '_custom_joins'):
                self._custom_joins.extend(rhs._custom_joins)


        def add_custom_join(self, table, condition, join_type,
                            condition_values=(), alias=None):
            """Record a join: <join_type> <table> AS <alias> ON <condition>.

            @param table: name of the table to join.
            @param condition: SQL for the ON clause.
            @param join_type: join type, e.g. Query.INNER or Query.LOUTER.
            @param condition_values: values to substitute into condition.
            @param alias: alias for the joined table; defaults to table.
            """
            if alias is None:
                alias = table
            join_dict = dict(table=table,
                             condition=condition,
                             condition_values=condition_values,
                             join_type=join_type,
                             alias=alias)
            self._custom_joins.append(join_dict)


        @classmethod
        def convert_query(cls, query_set):
            """
            Convert the query set's "query" attribute to a CustomQuery.

            @returns a copy of query_set; the argument itself is not changed.
            """
            # Make a copy of the query set
            query_set = query_set.all()
            # _custom_joins is passed explicitly because cloning a plain
            # Query into this class may not run CustomQuery.__init__
            # (NOTE(review): depends on Django's Query.clone implementation).
            query_set.query = query_set.query.clone(
                    klass=ExtendedManager.CustomQuery,
                    _custom_joins=[])
            return query_set


    class _WhereClause(object):
        """Object allowing us to inject arbitrary SQL into Django queries.

        By using this instead of extra(where=...), we can still freely combine
        queries with & and |.
        """
        def __init__(self, clause, values=()):
            self._clause = clause
            self._values = values


        def as_sql(self, qn=None, connection=None):
            """Return the raw clause and its parameter values unchanged."""
            return self._clause, self._values


        def relabel_aliases(self, change_map):
            """No-op: the raw SQL is not rewritten when aliases change."""
            return


    def add_join(self, query_set, join_table, join_key, join_condition='',
                 join_condition_values=(), join_from_key=None, alias=None,
                 suffix='', exclude=False, force_left_join=False):
        """Add a join to query_set.

        Join looks like this:
                (INNER|LEFT) JOIN <join_table> AS <alias>
                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
                        and <join_condition>)

        @param join_table table to join to
        @param join_key field referencing back to this model to use for the join
        @param join_condition extra condition for the ON clause of the join
        @param join_condition_values values to substitute into join_condition
        @param join_from_key column on this model to join from; defaults to
                the database column of this model's primary key.
        @param alias alias to use for the join
        @param suffix suffix to add to join_table for the join alias, if no
                alias is provided
        @param exclude if true, exclude rows that match this join (will use a
        LEFT OUTER JOIN and an appropriate WHERE condition)
        @param force_left_join - if true, a LEFT OUTER JOIN will be used
        instead of an INNER JOIN regardless of other options

        @returns a new query set (converted to a CustomQuery) with the join.
        """
        join_from_table = query_set.model._meta.db_table
        if join_from_key is None:
            # The key is emitted verbatim as an SQL column, so use the
            # primary key's column name, not its attribute name (they
            # differ e.g. for a ForeignKey primary key).
            join_from_key = self.model._meta.pk.column
        if alias is None:
            alias = join_table + suffix
        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
        full_join_condition = '%s = %s.%s' % (full_join_key,
                                              _quote_name(join_from_table),
                                              _quote_name(join_from_key))
        if join_condition:
            full_join_condition += ' AND (' + join_condition + ')'
        if exclude or force_left_join:
            join_type = query_set.query.LOUTER
        else:
            join_type = query_set.query.INNER

        query_set = self.CustomQuery.convert_query(query_set)
        query_set.query.add_custom_join(join_table,
                                        full_join_condition,
                                        join_type,
                                        condition_values=join_condition_values,
                                        alias=alias)

        if exclude:
            # With a LEFT JOIN, rows that found no match have NULL join keys.
            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])

        return query_set


    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        Build join info for a ForeignKey on the related model.

        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to.  NOTE: its WHERE tree is relabeled in place so
                that references to the related table point at `alias`.
        @param alias: alias of joined table
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        rhs_where = join_to_query.query.where
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        initial_clause, values = compiler.as_sql()
        # initial_clause is compiled from `join_to_query`, a SELECT query
        # returning at most one record. To be used in a WHERE clause it must
        # be converted to a boolean value using EXISTS.
        all_clauses = ('EXISTS (%s)' % initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                    ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info


    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
                                    m2m_is_on_this_model):
        """
        Build join info for a many-to-many relationship.

        @param m2m_field: a Django field representing the M2M relationship.
                It uses a pivot table with the following structure:
                this model table <---> M2M pivot table <---> joined model table
        @param join_to_query: the query over the related model that we're
                joining to. It must match exactly one related object.
        @param alias: alias of joined table
        @param m2m_is_on_this_model: True if m2m_field is declared on this
                manager's model, False if on the related model.
        @returns dict with keys rhs_table, rhs_column, lhs_column,
                where_clause and values, as consumed by join_custom_field().
        """
        if m2m_is_on_this_model:
            # referenced field on this model
            lhs_id_field = self.model._meta.pk
            # foreign key on the pivot table referencing lhs_id_field
            m2m_lhs_column = m2m_field.m2m_column_name()
            # foreign key on the pivot table referencing rhs_id_field
            m2m_rhs_column = m2m_field.m2m_reverse_name()
            # referenced field on related model
            rhs_id_field = m2m_field.rel.get_related_field()
        else:
            lhs_id_field = m2m_field.rel.get_related_field()
            m2m_lhs_column = m2m_field.m2m_reverse_name()
            m2m_rhs_column = m2m_field.m2m_column_name()
            rhs_id_field = join_to_query.model._meta.pk

        info = {}
        info['rhs_table'] = m2m_field.m2m_db_table()
        info['rhs_column'] = m2m_lhs_column
        info['lhs_column'] = lhs_id_field.column

        # select the ID of related models relevant to this join.  we can only do
        # a single join, so we need to gather this information up front and
        # include it in the join condition.
        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
                                   'match a single related object.')
        rhs_id = rhs_ids[0]

        # rhs_id comes from the database (an ID value), not from user input.
        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
                                               _quote_name(m2m_rhs_column),
                                               rhs_id)
        info['values'] = ()
        return info


    def join_custom_field(self, query_set, join_to_query, alias,
                          left_join=True):
        """Join to a related model to create a custom field in the given query.

        This method is used to construct a custom field on the given query based
        on a many-valued relationship.  join_to_query should be a simple query
        (no joins) on the related model which returns at most one related row
        per instance of this model.

        For many-to-one relationships, the joined table contains the matching
        row from the related model if one is related, NULL otherwise.

        For many-to-many relationships, the joined table contains the matching
        row if it's related, NULL otherwise.

        @param left_join: if true (default), use a LEFT OUTER JOIN.
        @returns a new query set with the join added (see add_join()).
        """
        relationship_type, field = self.determine_relationship(
                join_to_query.model)

        if relationship_type == self.MANY_TO_ONE:
            info = self._info_for_many_to_one_join(field, join_to_query, alias)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=False)
        elif relationship_type == self.M2M_ON_THIS_MODEL:
            info = self._info_for_many_to_many_join(
                    m2m_field=field, join_to_query=join_to_query, alias=alias,
                    m2m_is_on_this_model=True)

        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
                             join_from_key=info['lhs_column'],
                             join_condition=info['where_clause'],
                             join_condition_values=info['values'],
                             alias=alias,
                             force_left_join=left_join)


    def add_where(self, query_set, where, values=()):
        """Return a copy of query_set with raw SQL `where` ANDed in.

        Unlike extra(where=...), the result can still be combined with & / |.
        """
        query_set = query_set.all()
        query_set.query.where.add(self._WhereClause(where, values),
                                  django.db.models.sql.where.AND)
        return query_set


    def _get_quoted_field(self, table, field):
        """Return the quoted `table`.`field` reference."""
        return _quote_name(table) + '.' + _quote_name(field)


    def get_key_on_this_table(self, key_field=None):
        """Return the quoted column reference for key_field on this table.

        @param key_field: column name; defaults to the primary key's column.
        """
        if key_field is None:
            # default to primary key
            key_field = self.model._meta.pk.column
        return self._get_quoted_field(self.model._meta.db_table, key_field)


    def escape_user_sql(self, sql):
        """Double '%' so user SQL survives parameter substitution."""
        return sql.replace('%', '%%')


    def _custom_select_query(self, query_set, selects):
        """Execute a custom select query.

        The FROM/WHERE/... part of query_set's SQL is reused; only the
        selected columns are replaced.

        @param query_set: query set as returned by query_objects.
        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.

        @returns: Result of the query as returned by cursor.fetchall().
        """
        compiler = query_set.query.get_compiler(using=query_set.db)
        sql, params = compiler.as_sql()
        # NOTE(review): assumes the first ' FROM' in the SQL starts the FROM
        # clause, i.e. no selected column contains that text.
        from_ = sql[sql.find(' FROM'):]

        if query_set.query.distinct:
            distinct = 'DISTINCT '
        else:
            distinct = ''

        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
        # Choose the connection that's responsible for this type of object
        cursor = connections[query_set.db].cursor()
        cursor.execute(sql_query, params)
        return cursor.fetchall()


    def _is_relation_to(self, field, model_class):
        """Truthy if `field` is a relation pointing at model_class."""
        return field.rel and field.rel.to is model_class


    # Sentinels returned by determine_relationship().
    MANY_TO_ONE = object()
    M2M_ON_RELATED_MODEL = object()
    M2M_ON_THIS_MODEL = object()

    def determine_relationship(self, related_model):
        """
        Determine the relationship between this model and related_model.

        related_model must have some sort of many-valued relationship to this
        manager's model.
        @returns (relationship_type, field), where relationship_type is one of
                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
                is the Django field object for the relationship.
        @raises ValueError if no such relationship exists.
        """
        # look for a foreign key field on related_model relating to this model
        for field in related_model._meta.fields:
            if self._is_relation_to(field, self.model):
                return self.MANY_TO_ONE, field

        # look for an M2M field on related_model relating to this model
        for field in related_model._meta.many_to_many:
            if self._is_relation_to(field, self.model):
                return self.M2M_ON_RELATED_MODEL, field

        # maybe this model has the many-to-many field
        for field in self.model._meta.many_to_many:
            if self._is_relation_to(field, related_model):
                return self.M2M_ON_THIS_MODEL, field

        raise ValueError('%s has no relation to %s' %
                         (related_model, self.model))


    def _get_pivot_iterator(self, base_objects_by_id, related_model):
        """
        Determine the relationship between this model and related_model, and
        return a pivot iterator.
        @param base_objects_by_id: dict of instances of this model indexed by
        their IDs
        @returns a pivot iterator, which yields a tuple (base_object,
        related_object) for each relationship between a base object and a
        related object.  all base_object instances come from base_objects_by_id.
        Note -- this depends on Django model internals.
        """
        relationship_type, field = self.determine_relationship(related_model)
        if relationship_type == self.MANY_TO_ONE:
            return self._many_to_one_pivot(base_objects_by_id,
                                           related_model, field)
        elif relationship_type == self.M2M_ON_RELATED_MODEL:
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_reverse_name(), field.m2m_column_name())
        else:
            assert relationship_type == self.M2M_ON_THIS_MODEL
            return self._many_to_many_pivot(
                    base_objects_by_id, related_model, field.m2m_db_table(),
                    field.m2m_column_name(), field.m2m_reverse_name())


    def _many_to_one_pivot(self, base_objects_by_id, related_model,
                           foreign_key_field):
        """
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        filter_data = {foreign_key_field.name + '__pk__in':
                       base_objects_by_id.keys()}
        for related_object in related_model.objects.filter(**filter_data):
            # lookup base object in the dict, rather than grabbing it from the
            # related object.  we need to return instances from the dict, not
            # fresh instances of the same models (and grabbing model instances
            # from the related models incurs a DB query each time).
            base_object_id = getattr(related_object, foreign_key_field.attname)
            base_object = base_objects_by_id[base_object_id]
            yield base_object, related_object


    def _query_pivot_table(self, base_objects_by_id, pivot_table,
                           pivot_from_field, pivot_to_field, related_model):
        """
        @param base_objects_by_id dict of self.model objects to include,
        keyed by ID; must not be empty (an empty IN () is invalid SQL)
        @param pivot_table the name of the pivot table
        @param pivot_from_field a field name on pivot_table referencing
        self.model
        @param pivot_to_field a field name on pivot_table referencing the
        related model.
        @param related_model the related model

        @returns pivot list of IDs (base_id, related_id)
        """
        query = """
        SELECT %(from_field)s, %(to_field)s
        FROM %(table)s
        WHERE %(from_field)s IN (%(id_list)s)
        """ % dict(from_field=pivot_from_field,
                   to_field=pivot_to_field,
                   table=pivot_table,
                   id_list=','.join(str(id_) for id_
                                    in base_objects_by_id.iterkeys()))

        # Choose the connection that's responsible for this type of object
        # The databases for related_model and the current model will always
        # be the same, related_model is just easier to obtain here because
        # self is only a ExtendedManager, not the object.
        cursor = connections[related_model.objects.db].cursor()
        cursor.execute(query)
        return cursor.fetchall()


    def _many_to_many_pivot(self, base_objects_by_id, related_model,
                            pivot_table, pivot_from_field, pivot_to_field):
        """
        @param pivot_table: see _query_pivot_table
        @param pivot_from_field: see _query_pivot_table
        @param pivot_to_field: see _query_pivot_table
        @returns a pivot iterator - see _get_pivot_iterator()
        """
        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
                                           pivot_from_field, pivot_to_field,
                                           related_model)

        # Fetch every related object in one query, then pair them up.
        all_related_ids = list(set(related_id for base_id, related_id
                                   in id_pivot))
        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)

        for base_id, related_id in id_pivot:
            yield base_objects_by_id[base_id], related_objects_by_id[related_id]


    def populate_relationships(self, base_objects, related_model,
                               related_list_name):
        """
        For each instance of this model in base_objects, add a field named
        related_list_name listing all the related objects of type related_model.
        related_model must be in a many-to-one or many-to-many relationship with
        this model.
        @param base_objects - list of instances of this model
        @param related_model - model class related to this model
        @param related_list_name - attribute name in which to store the related
        object list.
        """
        if not base_objects:
            # if we don't bail early, we'll get a SQL error later
            return

        # The default maximum value of a host parameter number in SQLite is 999.
        # Exceeding this will get a DatabaseError later.
        batch_size = 900
        for i in xrange(0, len(base_objects), batch_size):
            base_objects_batch = base_objects[i:i + batch_size]
            base_objects_by_id = dict((base_object._get_pk_val(), base_object)
                                      for base_object in base_objects_batch)
            pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
                                                      related_model)

            for base_object in base_objects_batch:
                setattr(base_object, related_list_name, [])

            for base_object, related_object in pivot_iterator:
                getattr(base_object, related_list_name).append(related_object)
474
475
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """QuerySet whose delete() goes through each object's own delete().

    Intended for models with an "invalid" bit. NOTE(review): presumably those
    models override delete() to flip the bit rather than drop the row --
    confirm against the model definitions.
    """
    def delete(self):
        # One delete() call per instance instead of a single bulk SQL DELETE,
        # so each instance's delete() method runs.
        for instance in self:
            instance.delete()
483
484
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit.

    Hands out ModelWithInvalidQuerySet instances so that bulk delete() is
    routed through each object's delete() method.
    """
    def get_query_set(self):
        # Pass the manager's database alias through, as Django's own
        # Manager.get_query_set() does, so db_manager(using=...) is honoured.
        # With the default alias (None) this matches the previous behaviour.
        return ModelWithInvalidQuerySet(self.model, using=self._db)
491
492
class ValidObjectsManager(ModelWithInvalidManager):
    """Manager that only returns objects whose invalid flag is False."""
    def get_query_set(self):
        parent = super(ValidObjectsManager, self)
        return parent.get_query_set().filter(invalid=False)
500
501
502class ModelExtensions(rdb_model_extensions.ModelValidators):
503    """\
504    Mixin with convenience functions for models, built on top of
505    the model validators in rdb_model_extensions.
506    """
507    # TODO: at least some of these functions really belong in a custom
508    # Manager class
509
510
511    SERIALIZATION_LINKS_TO_FOLLOW = set()
512    """
513    To be able to send jobs and hosts to shards, it's necessary to find their
514    dependencies.
515    The most generic approach for this would be to traverse all relationships
516    to other objects recursively. This would list all objects that are related
517    in any way.
518    But this approach finds too many objects: If a host should be transferred,
519    all it's relationships would be traversed. This would find an acl group.
520    If then the acl group's relationships are traversed, the relationship
521    would be followed backwards and many other hosts would be found.
522
523    This mapping tells that algorithm which relations to follow explicitly.
524    """
525
526
527    SERIALIZATION_LINKS_TO_KEEP = set()
528    """This set stores foreign keys which we don't want to follow, but
529    still want to include in the serialized dictionary. For
530    example, we follow the relationship `Host.hostattribute_set`,
531    but we do not want to follow `HostAttributes.host_id` back to
532    to Host, which would otherwise lead to a circle. However, we still
533    like to serialize HostAttribute.`host_id`."""
534
535    SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
536    """
537    On deserializion, if the object to persist already exists, local fields
538    will only be updated, if their name is in this set.
539    """
540
541
542    @classmethod
543    def convert_human_readable_values(cls, data, to_human_readable=False):
544        """\
545        Performs conversions on user-supplied field data, to make it
546        easier for users to pass human-readable data.
547
548        For all fields that have choice sets, convert their values
549        from human-readable strings to enum values, if necessary.  This
550        allows users to pass strings instead of the corresponding
551        integer values.
552
553        For all foreign key fields, call smart_get with the supplied
554        data.  This allows the user to pass either an ID value or
555        the name of the object as a string.
556
557        If to_human_readable=True, perform the inverse - i.e. convert
558        numeric values to human readable values.
559
560        This method modifies data in-place.
561        """
562        field_dict = cls.get_field_dict()
563        for field_name in data:
564            if field_name not in field_dict or data[field_name] is None:
565                continue
566            field_obj = field_dict[field_name]
567            # convert enum values
568            if field_obj.choices:
569                for choice_data in field_obj.choices:
570                    # choice_data is (value, name)
571                    if to_human_readable:
572                        from_val, to_val = choice_data
573                    else:
574                        to_val, from_val = choice_data
575                    if from_val == data[field_name]:
576                        data[field_name] = to_val
577                        break
578            # convert foreign key values
579            elif field_obj.rel:
580                dest_obj = field_obj.rel.to.smart_get(data[field_name],
581                                                      valid_only=False)
582                if to_human_readable:
583                    # parameterized_jobs do not have a name_field
584                    if (field_name != 'parameterized_job' and
585                        dest_obj.name_field is not None):
586                        data[field_name] = getattr(dest_obj,
587                                                   dest_obj.name_field)
588                else:
589                    data[field_name] = dest_obj
590
591
592
593
594    def _validate_unique(self):
595        """\
596        Validate that unique fields are unique.  Django manipulators do
597        this too, but they're a huge pain to use manually.  Trust me.
598        """
599        errors = {}
600        cls = type(self)
601        field_dict = self.get_field_dict()
602        manager = cls.get_valid_manager()
603        for field_name, field_obj in field_dict.iteritems():
604            if not field_obj.unique:
605                continue
606
607            value = getattr(self, field_name)
608            if value is None and field_obj.auto_created:
609                # don't bother checking autoincrement fields about to be
610                # generated
611                continue
612
613            existing_objs = manager.filter(**{field_name : value})
614            num_existing = existing_objs.count()
615
616            if num_existing == 0:
617                continue
618            if num_existing == 1 and existing_objs[0].id == self.id:
619                continue
620            errors[field_name] = (
621                'This value must be unique (%s)' % (value))
622        return errors
623
624
625    def _validate(self):
626        """
627        First coerces all fields on this instance to their proper Python types.
628        Then runs validation on every field. Returns a dictionary of
629        field_name -> error_list.
630
631        Based on validate() from django.db.models.Model in Django 0.96, which
632        was removed in Django 1.0. It should reappear in a later version. See:
633            http://code.djangoproject.com/ticket/6845
634        """
635        error_dict = {}
636        for f in self._meta.fields:
637            try:
638                python_value = f.to_python(
639                    getattr(self, f.attname, f.get_default()))
640            except django.core.exceptions.ValidationError, e:
641                error_dict[f.name] = str(e)
642                continue
643
644            if not f.blank and not python_value:
645                error_dict[f.name] = 'This field is required.'
646                continue
647
648            setattr(self, f.attname, python_value)
649
650        return error_dict
651
652
653    def do_validate(self):
654        errors = self._validate()
655        unique_errors = self._validate_unique()
656        for field_name, error in unique_errors.iteritems():
657            errors.setdefault(field_name, error)
658        if errors:
659            raise ValidationError(errors)
660
661
662    # actually (externally) useful methods follow
663
664    @classmethod
665    def add_object(cls, data={}, **kwargs):
666        """\
667        Returns a new object created with the given data (a dictionary
668        mapping field names to values). Merges any extra keyword args
669        into data.
670        """
671        data = dict(data)
672        data.update(kwargs)
673        data = cls.prepare_data_args(data)
674        cls.convert_human_readable_values(data)
675        data = cls.provide_default_values(data)
676
677        obj = cls(**data)
678        obj.do_validate()
679        obj.save()
680        return obj
681
682
683    def update_object(self, data={}, **kwargs):
684        """\
685        Updates the object with the given data (a dictionary mapping
686        field names to values).  Merges any extra keyword args into
687        data.
688        """
689        data = dict(data)
690        data.update(kwargs)
691        data = self.prepare_data_args(data)
692        self.convert_human_readable_values(data)
693        for field_name, value in data.iteritems():
694            setattr(self, field_name, value)
695        self.do_validate()
696        self.save()
697
698
699    # see query_objects()
700    _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
701                            'extra_args', 'extra_where', 'no_distinct')
702
703
704    @classmethod
705    def _extract_special_params(cls, filter_data):
706        """
707        @returns a tuple of dicts (special_params, regular_filters), where
708        special_params contains the parameters we handle specially and
709        regular_filters is the remaining data to be handled by Django.
710        """
711        regular_filters = dict(filter_data)
712        special_params = {}
713        for key in cls._SPECIAL_FILTER_KEYS:
714            if key in regular_filters:
715                special_params[key] = regular_filters.pop(key)
716        return special_params, regular_filters
717
718
719    @classmethod
720    def apply_presentation(cls, query, filter_data):
721        """
722        Apply presentation parameters -- sorting and paging -- to the given
723        query.
724        @returns new query with presentation applied
725        """
726        special_params, _ = cls._extract_special_params(filter_data)
727        sort_by = special_params.get('sort_by', None)
728        if sort_by:
729            assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
730            query = query.extra(order_by=sort_by)
731
732        query_start = special_params.get('query_start', None)
733        query_limit = special_params.get('query_limit', None)
734        if query_start is not None:
735            if query_limit is None:
736                raise ValueError('Cannot pass query_start without query_limit')
737            # query_limit is passed as a page size
738            query_limit += query_start
739        return query[query_start:query_limit]
740
741
742    @classmethod
743    def query_objects(cls, filter_data, valid_only=True, initial_query=None,
744                      apply_presentation=True):
745        """\
746        Returns a QuerySet object for querying the given model_class
747        with the given filter_data.  Optional special arguments in
748        filter_data include:
749        -query_start: index of first return to return
750        -query_limit: maximum number of results to return
751        -sort_by: list of fields to sort on.  prefixing a '-' onto a
752         field name changes the sort to descending order.
753        -extra_args: keyword args to pass to query.extra() (see Django
754         DB layer documentation)
755        -extra_where: extra WHERE clause to append
756        -no_distinct: if True, a DISTINCT will not be added to the SELECT
757        """
758        special_params, regular_filters = cls._extract_special_params(
759                filter_data)
760
761        if initial_query is None:
762            if valid_only:
763                initial_query = cls.get_valid_manager()
764            else:
765                initial_query = cls.objects
766
767        query = initial_query.filter(**regular_filters)
768
769        use_distinct = not special_params.get('no_distinct', False)
770        if use_distinct:
771            query = query.distinct()
772
773        extra_args = special_params.get('extra_args', {})
774        extra_where = special_params.get('extra_where', None)
775        if extra_where:
776            # escape %'s
777            extra_where = cls.objects.escape_user_sql(extra_where)
778            extra_args.setdefault('where', []).append(extra_where)
779        if extra_args:
780            query = query.extra(**extra_args)
781            # TODO: Use readonly connection for these queries.
782            # This has been disabled, because it's not used anyway, as the
783            # configured readonly user is the same as the real user anyway.
784
785        if apply_presentation:
786            query = cls.apply_presentation(query, filter_data)
787
788        return query
789
790
791    @classmethod
792    def query_count(cls, filter_data, initial_query=None):
793        """\
794        Like query_objects, but retreive only the count of results.
795        """
796        filter_data.pop('query_start', None)
797        filter_data.pop('query_limit', None)
798        query = cls.query_objects(filter_data, initial_query=initial_query)
799        return query.count()
800
801
802    @classmethod
803    def clean_object_dicts(cls, field_dicts):
804        """\
805        Take a list of dicts corresponding to object (as returned by
806        query.values()) and clean the data to be more suitable for
807        returning to the user.
808        """
809        for field_dict in field_dicts:
810            cls.clean_foreign_keys(field_dict)
811            cls._convert_booleans(field_dict)
812            cls.convert_human_readable_values(field_dict,
813                                              to_human_readable=True)
814
815
816    @classmethod
817    def list_objects(cls, filter_data, initial_query=None):
818        """\
819        Like query_objects, but return a list of dictionaries.
820        """
821        query = cls.query_objects(filter_data, initial_query=initial_query)
822        extra_fields = query.query.extra_select.keys()
823        field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
824                       for model_object in query]
825        return field_dicts
826
827
828    @classmethod
829    def smart_get(cls, id_or_name, valid_only=True):
830        """\
831        smart_get(integer) -> get object by ID
832        smart_get(string) -> get object by name_field
833        """
834        if valid_only:
835            manager = cls.get_valid_manager()
836        else:
837            manager = cls.objects
838
839        if isinstance(id_or_name, (int, long)):
840            return manager.get(pk=id_or_name)
841        if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
842            return manager.get(**{cls.name_field : id_or_name})
843        raise ValueError(
844            'Invalid positional argument: %s (%s)' % (id_or_name,
845                                                      type(id_or_name)))
846
847
848    @classmethod
849    def smart_get_bulk(cls, id_or_name_list):
850        invalid_inputs = []
851        result_objects = []
852        for id_or_name in id_or_name_list:
853            try:
854                result_objects.append(cls.smart_get(id_or_name))
855            except cls.DoesNotExist:
856                invalid_inputs.append(id_or_name)
857        if invalid_inputs:
858            raise cls.DoesNotExist('The following %ss do not exist: %s'
859                                   % (cls.__name__.lower(),
860                                      ', '.join(invalid_inputs)))
861        return result_objects
862
863
864    def get_object_dict(self, extra_fields=None):
865        """\
866        Return a dictionary mapping fields to this object's values.  @param
867        extra_fields: list of extra attribute names to include, in addition to
868        the fields defined on this object.
869        """
870        fields = self.get_field_dict().keys()
871        if extra_fields:
872            fields += extra_fields
873        object_dict = dict((field_name, getattr(self, field_name))
874                           for field_name in fields)
875        self.clean_object_dicts([object_dict])
876        self._postprocess_object_dict(object_dict)
877        return object_dict
878
879
880    def _postprocess_object_dict(self, object_dict):
881        """For subclasses to override."""
882        pass
883
884
885    @classmethod
886    def get_valid_manager(cls):
887        return cls.objects
888
889
890    def _record_attributes(self, attributes):
891        """
892        See on_attribute_changed.
893        """
894        assert not isinstance(attributes, basestring)
895        self._recorded_attributes = dict((attribute, getattr(self, attribute))
896                                         for attribute in attributes)
897
898
899    def _check_for_updated_attributes(self):
900        """
901        See on_attribute_changed.
902        """
903        for attribute, original_value in self._recorded_attributes.iteritems():
904            new_value = getattr(self, attribute)
905            if original_value != new_value:
906                self.on_attribute_changed(attribute, original_value)
907        self._record_attributes(self._recorded_attributes.keys())
908
909
910    def on_attribute_changed(self, attribute, old_value):
911        """
912        Called whenever an attribute is updated.  To be overridden.
913
914        To use this method, you must:
915        * call _record_attributes() from __init__() (after making the super
916        call) with a list of attributes for which you want to be notified upon
917        change.
918        * call _check_for_updated_attributes() from save().
919        """
920        pass
921
922
923    def serialize(self, include_dependencies=True):
924        """Serializes the object with dependencies.
925
926        The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies
927        this function will serialize with the object.
928
929        @param include_dependencies: Whether or not to follow relations to
930                                     objects this object depends on.
931                                     This parameter is used when uploading
932                                     jobs from a shard to the main, as the
933                                     main already has all the dependent
934                                     objects.
935
936        @returns: Dictionary representation of the object.
937        """
938        serialized = {}
939        for field in self._meta.concrete_model._meta.local_fields:
940            if field.rel is None:
941                serialized[field.name] = field._get_val_from_obj(self)
942            elif field.name in self.SERIALIZATION_LINKS_TO_KEEP:
943                # attname will contain "_id" suffix for foreign keys,
944                # e.g. HostAttribute.host will be serialized as 'host_id'.
945                # Use it for easy deserialization.
946                serialized[field.attname] = field._get_val_from_obj(self)
947
948        if include_dependencies:
949            for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
950                serialized[link] = self._serialize_relation(link)
951
952        return serialized
953
954
955    def _serialize_relation(self, link):
956        """Serializes dependent objects given the name of the relation.
957
958        @param link: Name of the relation to take objects from.
959
960        @returns For To-Many relationships a list of the serialized related
961            objects, for To-One relationships the serialized related object.
962        """
963        try:
964            attr = getattr(self, link)
965        except AttributeError:
966            # One-To-One relationships that point to None may raise this
967            return None
968
969        if attr is None:
970            return None
971        if hasattr(attr, 'all'):
972            return [obj.serialize() for obj in attr.all()]
973        return attr.serialize()
974
975
976    @classmethod
977    def _split_local_from_foreign_values(cls, data):
978        """This splits local from foreign values in a serialized object.
979
980        @param data: The serialized object.
981
982        @returns A tuple of two lists, both containing tuples in the form
983                 (link_name, link_value). The first list contains all links
984                 for local fields, the second one contains those for foreign
985                 fields/objects.
986        """
987        links_to_local_values, links_to_related_values = [], []
988        for link, value in data.iteritems():
989            if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
990                # It's a foreign key
991                links_to_related_values.append((link, value))
992            else:
993                # It's a local attribute or a foreign key
994                # we don't want to follow.
995                links_to_local_values.append((link, value))
996        return links_to_local_values, links_to_related_values
997
998
999    @classmethod
1000    def _filter_update_allowed_fields(cls, data):
1001        """Filters data and returns only files that updates are allowed on.
1002
1003        This is i.e. needed for syncing aborted bits from the main to shards.
1004
1005        Local links are only allowed to be updated, if they are in
1006        SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1007        Overwriting existing values is allowed in order to be able to sync i.e.
1008        the aborted bit from the main to a shard.
1009
1010        The allowlisting mechanism is in place to prevent overwriting local
1011        status: If all fields were overwritten, jobs would be completely be
1012        set back to their original (unstarted) state.
1013
1014        @param data: List with tuples of the form (link_name, link_value), as
1015                     returned by _split_local_from_foreign_values.
1016
1017        @returns List of the same format as data, but only containing data for
1018                 fields that updates are allowed on.
1019        """
1020        return [pair for pair in data
1021                if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1022
1023
1024    @classmethod
1025    def delete_matching_record(cls, **filter_args):
1026        """Delete records matching the filter.
1027
1028        @param filter_args: Arguments for the django filter
1029                used to locate the record to delete.
1030        """
1031        try:
1032            existing_record = cls.objects.get(**filter_args)
1033        except cls.DoesNotExist:
1034            return
1035        existing_record.delete()
1036
1037
1038    def _deserialize_local(self, data):
1039        """Set local attributes from a list of tuples.
1040
1041        @param data: List of tuples like returned by
1042                     _split_local_from_foreign_values.
1043        """
1044        if not data:
1045            return
1046
1047        for link, value in data:
1048            setattr(self, link, value)
1049        # Overwridden save() methods are prone to errors, so don't execute them.
1050        # This is because:
1051        # - the overwritten methods depend on ACL groups that don't yet exist
1052        #   and don't handle errors
1053        # - the overwritten methods think this object already exists in the db
1054        #   because the id is already set
1055        super(type(self), self).save()
1056
1057
1058    def _deserialize_relations(self, data):
1059        """Set foreign attributes from a list of tuples.
1060
1061        This deserialized the related objects using their own deserialize()
1062        function and then sets the relation.
1063
1064        @param data: List of tuples like returned by
1065                     _split_local_from_foreign_values.
1066        """
1067        for link, value in data:
1068            self._deserialize_relation(link, value)
1069        # See comment in _deserialize_local
1070        super(type(self), self).save()
1071
1072
1073    @classmethod
1074    def get_record(cls, data):
1075        """Retrieve a record with the data in the given input arg.
1076
1077        @param data: A dictionary containing the information to use in a query
1078                for data. If child models have different constraints of
1079                uniqueness they should override this model.
1080
1081        @return: An object with matching data.
1082
1083        @raises DoesNotExist: If a record with the given data doesn't exist.
1084        """
1085        return cls.objects.get(id=data['id'])
1086
1087
1088    @classmethod
1089    def deserialize(cls, data):
1090        """Recursively deserializes and saves an object with it's dependencies.
1091
1092        This takes the result of the serialize method and creates objects
1093        in the database that are just like the original.
1094
1095        If an object of the same type with the same id already exists, it's
1096        local values will be left untouched, unless they are explicitly
1097        allowlisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1098
1099        Deserialize will always recursively propagate to all related objects
1100        present in data though.
1101        I.e. this is necessary to add users to an already existing acl-group.
1102
1103        @param data: Representation of an object and its dependencies, as
1104                     returned by serialize.
1105
1106        @returns: The object represented by data if it didn't exist before,
1107                  otherwise the object that existed before and has the same type
1108                  and id as the one described by data.
1109        """
1110        if data is None:
1111            return None
1112
1113        local, related = cls._split_local_from_foreign_values(data)
1114        try:
1115            instance = cls.get_record(data)
1116            local = cls._filter_update_allowed_fields(local)
1117        except cls.DoesNotExist:
1118            instance = cls()
1119
1120        instance._deserialize_local(local)
1121        instance._deserialize_relations(related)
1122
1123        return instance
1124
1125
1126    def sanity_check_update_from_shard(self, shard, updated_serialized,
1127                                       *args, **kwargs):
1128        """Check if an update sent from a shard is legitimate.
1129
1130        @raises error.UnallowedRecordsSentToMain if an update is not
1131                legitimate.
1132        """
1133        raise NotImplementedError(
1134            'sanity_check_update_from_shard must be implemented by subclass %s '
1135            'for type %s' % type(self))
1136
1137
1138    @transaction.commit_on_success
1139    def update_from_serialized(self, serialized):
1140        """Updates local fields of an existing object from a serialized form.
1141
1142        This is different than the normal deserialize() in the way that it
1143        does update local values, which deserialize doesn't, but doesn't
1144        recursively propagate to related objects, which deserialize() does.
1145
1146        The use case of this function is to update job records on the main
1147        after the jobs have been executed on a shard, as the main is not
1148        interested in updates for users, labels, specialtasks, etc.
1149
1150        @param serialized: Representation of an object and its dependencies, as
1151                           returned by serialize.
1152
1153        @raises ValueError: if serialized contains related objects, i.e. not
1154                            only local fields.
1155        """
1156        local, related = (
1157            self._split_local_from_foreign_values(serialized))
1158        if related:
1159            raise ValueError('Serialized must not contain foreign '
1160                             'objects: %s' % related)
1161
1162        self._deserialize_local(local)
1163
1164
1165    def custom_deserialize_relation(self, link, data):
1166        """Allows overriding the deserialization behaviour by subclasses."""
1167        raise NotImplementedError(
1168            'custom_deserialize_relation must be implemented by subclass %s '
1169            'for relation %s' % (type(self), link))
1170
1171
1172    def _deserialize_relation(self, link, data):
1173        """Deserializes related objects and sets references on this object.
1174
1175        Relations that point to a list of objects are handled automatically.
1176        For many-to-one or one-to-one relations custom_deserialize_relation
1177        must be overridden by the subclass.
1178
1179        Related objects are deserialized using their deserialize() method.
1180        Thereby they and their dependencies are created if they don't exist
1181        and saved to the database.
1182
1183        @param link: Name of the relation.
1184        @param data: Serialized representation of the related object(s).
1185                     This means a list of dictionaries for to-many relations,
1186                     just a dictionary for to-one relations.
1187        """
1188        field = getattr(self, link)
1189
1190        if field and hasattr(field, 'all'):
1191            self._deserialize_2m_relation(link, data, field.model)
1192        else:
1193            self.custom_deserialize_relation(link, data)
1194
1195
1196    def _deserialize_2m_relation(self, link, data, related_class):
1197        """Deserialize related objects for one to-many relationship.
1198
1199        @param link: Name of the relation.
1200        @param data: Serialized representation of the related objects.
1201                     This is a list with of dictionaries.
1202        @param related_class: A class representing a django model, with which
1203                              this class has a one-to-many relationship.
1204        """
1205        relation_set = getattr(self, link)
1206        if related_class == self.get_attribute_model():
1207            # When deserializing a model together with
1208            # its attributes, clear all the exising attributes to ensure
1209            # db consistency. Note 'update' won't be sufficient, as we also
1210            # want to remove any attributes that no longer exist in |data|.
1211            #
1212            # core_filters is a dictionary of filters, defines how
1213            # RelatedMangager would query for the 1-to-many relationship. E.g.
1214            # Host.objects.get(
1215            #     id=20).hostattribute_set.core_filters = {host_id:20}
1216            # We use it to delete objects related to the current object.
1217            related_class.objects.filter(**relation_set.core_filters).delete()
1218        for serialized in data:
1219            relation_set.add(related_class.deserialize(serialized))
1220
1221
1222    @classmethod
1223    def get_attribute_model(cls):
1224        """Return the attribute model.
1225
1226        Subclass with attribute-like model should override this to
1227        return the attribute model class. This method will be
1228        called by _deserialize_2m_relation to determine whether
1229        to clear the one-to-many relations first on deserialization of object.
1230        """
1231        return None
1232
1233
class ModelWithInvalid(ModelExtensions):
    """
    Overrides model methods save() and delete() to support invalidation in
    place of actual deletion.  Subclasses must have a boolean "invalid"
    field.  They also need a name_field attribute, which save() uses to find
    a previously invalidated object with the same name.
    """

    def save(self, *args, **kwargs):
        """Save the object, reviving a previously invalidated twin if any.

        On the first save (self.id is None), save() looks for an invalid
        object with the same name_field value.  If one exists,
        resurrect_object() is called with it; the base implementation copies
        its id so the old row is reused.  Arguments are passed through to the
        parent save().
        """
        first_time = (self.id is None)
        if first_time:
            # see if this object was previously added and invalidated
            my_name = getattr(self, self.name_field)
            filters = {self.name_field : my_name, 'invalid' : True}
            try:
                old_object = self.__class__.objects.get(**filters)
            except self.DoesNotExist:
                # no existing object
                pass
            else:
                # Outside the try, so a DoesNotExist raised while resurrecting
                # is not mistaken for "no invalidated object exists".
                self.resurrect_object(old_object)

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is actually
        "undeleting" a previously deleted object.  Can be overridden by
        subclasses to copy data as desired from the deleted entry (but this
        superclass implementation must normally be called).
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """
        pass


    def delete(self):
        """Mark the object invalid instead of deleting its database row.

        Asserts the object is not already invalid, saves it with
        invalid=True, then calls clean_object() so subclasses can drop
        relationships.
        """
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return the manager that only yields non-invalid objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
1299
1300
class ModelWithAttributes(object):
    """
    Mixin for models that have an associated attribute model.

    The attribute model is assumed to store its value in a field named
    "value".  Subclasses supply the attribute model and lookup arguments by
    overriding _get_attribute_model_and_args().
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Return (attribute_model, keyword_args) for the given attribute.

        Subclasses must override this.  attribute_model is a model class.
        keyword_args is a dict of arguments that, passed to
        attribute_model.objects.get(), selects this object's instance of
        the attribute.
        """
        raise NotImplementedError


    def _is_replaced_by_static_attribute(self, attribute):
        """
        Return whether attribute is static and therefore unmodifiable.

        Subclasses with static attributes can override this; by default no
        attribute is static.
        """
        return False


    def _refuse_if_static(self, attribute, message_template):
        """Raise UnmodifiableAttributeException if attribute is static.

        @param attribute: name of the attribute being modified.
        @param message_template: format string for the exception message,
                filled with (attribute, self.hostname).
        """
        if not self._is_replaced_by_static_attribute(attribute):
            return
        raise error.UnmodifiableAttributeException(
                message_template % (attribute, self.hostname))


    def set_attribute(self, attribute, value):
        """Create or update the attribute so that it holds value."""
        self._refuse_if_static(
                attribute,
                'Failed to set attribute "%s" for host "%s" since it '
                'is static. Use go/chromeos-skylab-inventory-tools to '
                'modify this attribute.')
        model, lookup = self._get_attribute_model_and_args(attribute)
        instance = model.objects.get_or_create(**lookup)[0]
        instance.value = value
        instance.save()


    def delete_attribute(self, attribute):
        """Delete the attribute; a missing attribute is silently ignored."""
        self._refuse_if_static(
                attribute,
                'Failed to delete attribute "%s" for host "%s" since it '
                'is static. Use go/chromeos-skylab-inventory-tools to '
                'modify this attribute.')
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            instance = model.objects.get(**lookup)
        except model.DoesNotExist:
            return
        instance.delete()


    def set_or_delete_attribute(self, attribute, value):
        """Delete the attribute if value is None, otherwise set it."""
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
1359
1360
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class.

    Objects must be created through get_or_create(), which fills in the
    'the_hash' column from the other keyword arguments.
    """

    def create(self, **kwargs):
        """Always raises: use get_or_create() instead."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """Like Manager.get_or_create(), with 'the_hash' computed.

        The hash is computed by the model's _compute_hash() from the given
        keyword arguments.  Any 'the_hash' passed by the caller is replaced.
        """
        lookup = dict(kwargs)
        lookup['the_hash'] = self.model._compute_hash(**kwargs)
        return super(ModelWithHashManager, self).get_or_create(**lookup)
1372
1373
class ModelWithHash(dbmodels.Model):
    """Abstract base for immutable models identified by a hash column.

    Instances should be obtained with model.objects.get_or_create(); the
    manager computes the_hash via _compute_hash().
    """

    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Return the hash for an object with the given field values.

        Subclasses must override this.
        """
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Refuse ordinary saves; allow only forced inserts.

        These models are meant to be immutable, so a plain save() raises.
        save(force_insert=True) is permitted because it only creates a new
        row; a duplicate is rejected by the unique constraint on the_hash.
        The preferred way to create instances is still
        model.objects.get_or_create().
        """
        if force_insert:
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')
1406