• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Extensions to Django's model logic.
3"""
4
5import django.core.exceptions
6import django.db.models.sql.where
7import six
8from autotest_lib.client.common_lib import error
9from autotest_lib.frontend.afe import rdb_model_extensions
10from django.db import connection, connections
11from django.db import models as dbmodels
12from django.db import transaction
13from django.db.models.sql import query
14
15
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Data validation error in adding or updating an object. The associated
    value is a dictionary mapping field names to error strings.

    ModelExtensions.do_validate() raises this with the merged results of
    its per-field and uniqueness checks.
    """
21
def _quote_name(name):
    """Quote an SQL identifier via django.db.connection.ops.quote_name().

    @param name: table, alias or column name to quote.
    @returns the quoted identifier.
    """
    quoted = connection.ops.quote_name(name)
    return quoted
25
26
class LeasedHostManager(dbmodels.Manager):
    """Query manager restricted to hosts with leased=0 and locked=0."""
    def get_query_set(self):
        """Return the parent manager's query set filtered on both flags."""
        all_hosts = super(LeasedHostManager, self).get_query_set()
        return all_hosts.filter(leased=0, locked=0)
33
34
35class ExtendedManager(dbmodels.Manager):
36    """\
37    Extended manager supporting subquery filtering.
38    """
39
40    class CustomQuery(query.Query):
41        """A custom query"""
42
43        def __init__(self, *args, **kwargs):
44            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
45            self._custom_joins = []
46
47
48        def clone(self, klass=None, **kwargs):
49            """Clones the query and returns the clone."""
50            obj = super(ExtendedManager.CustomQuery, self).clone(klass)
51            obj._custom_joins = list(self._custom_joins)
52            return obj
53
54
55        def combine(self, rhs, connector):
56            """Combines query with another query."""
57            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
58            if hasattr(rhs, '_custom_joins'):
59                self._custom_joins.extend(rhs._custom_joins)
60
61
62        def add_custom_join(self, table, condition, join_type,
63                            condition_values=(), alias=None):
64            """Adds a custom join to the query."""
65            if alias is None:
66                alias = table
67            join_dict = dict(table=table,
68                             condition=condition,
69                             condition_values=condition_values,
70                             join_type=join_type,
71                             alias=alias)
72            self._custom_joins.append(join_dict)
73
74
75        @classmethod
76        def convert_query(self, query_set):
77            """
78            Convert the query set's "query" attribute to a CustomQuery.
79            """
80            # Make a copy of the query set
81            query_set = query_set.all()
82            query_set.query = query_set.query.clone(
83                    klass=ExtendedManager.CustomQuery,
84                    _custom_joins=[])
85            return query_set
86
87
88    class _WhereClause(object):
89        """Object allowing us to inject arbitrary SQL into Django queries.
90
91        By using this instead of extra(where=...), we can still freely combine
92        queries with & and |.
93        """
94        def __init__(self, clause, values=()):
95            self._clause = clause
96            self._values = values
97
98
99        def as_sql(self, qn=None, connection=None):
100            """Converts the clause to SQL and returns it."""
101            return self._clause, self._values
102
103
104        def relabel_aliases(self, change_map):
105            """Does nothing."""
106            return
107
108
109    def add_join(self, query_set, join_table, join_key, join_condition='',
110                 join_condition_values=(), join_from_key=None, alias=None,
111                 suffix='', exclude=False, force_left_join=False):
112        """Add a join to query_set.
113
114        Join looks like this:
115                (INNER|LEFT) JOIN <join_table> AS <alias>
116                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
117                        and <join_condition>)
118
119        @param join_table table to join to
120        @param join_key field referencing back to this model to use for the join
121        @param join_condition extra condition for the ON clause of the join
122        @param join_condition_values values to substitute into join_condition
123        @param join_from_key column on this model to join from.
124        @param alias alias to use for for join
125        @param suffix suffix to add to join_table for the join alias, if no
126                alias is provided
127        @param exclude if true, exclude rows that match this join (will use a
128        LEFT OUTER JOIN and an appropriate WHERE condition)
129        @param force_left_join - if true, a LEFT OUTER JOIN will be used
130        instead of an INNER JOIN regardless of other options
131        """
132        join_from_table = query_set.model._meta.db_table
133        if join_from_key is None:
134            join_from_key = self.model._meta.pk.name
135        if alias is None:
136            alias = join_table + suffix
137        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
138        full_join_condition = '%s = %s.%s' % (full_join_key,
139                                              _quote_name(join_from_table),
140                                              _quote_name(join_from_key))
141        if join_condition:
142            full_join_condition += ' AND (' + join_condition + ')'
143        if exclude or force_left_join:
144            join_type = query_set.query.LOUTER
145        else:
146            join_type = query_set.query.INNER
147
148        query_set = self.CustomQuery.convert_query(query_set)
149        query_set.query.add_custom_join(join_table,
150                                        full_join_condition,
151                                        join_type,
152                                        condition_values=join_condition_values,
153                                        alias=alias)
154
155        if exclude:
156            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
157
158        return query_set
159
160
    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        Collect the pieces needed to join to a model that has a foreign key
        pointing at this model.

        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns a dict with the keys 'rhs_table', 'rhs_column',
                'lhs_column', 'where_clause' and 'values', as consumed by
                join_custom_field()
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        # Re-point the related query's conditions at the join alias before
        # compiling it.  NOTE(review): relabel_aliases() is called on
        # join_to_query's own where tree, so this presumably modifies the
        # caller's query in place -- confirm.
        rhs_where = join_to_query.query.where
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        initial_clause, values = compiler.as_sql()
        # initial_clause is the full SELECT statement compiled from
        # `join_to_query`, which is expected to return at most one record.
        # To be usable as a condition it has to be turned into a boolean
        # value, hence the EXISTS wrapper.
        all_clauses = ('EXISTS (%s)' % initial_clause,)
        if hasattr(join_to_query.query, 'extra_where'):
            # NOTE(review): tuple concatenation -- assumes extra_where is a
            # tuple; confirm against the Django version in use.
            all_clauses += join_to_query.query.extra_where
        info['where_clause'] = (
                    ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info
187
188
189    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
190                                    m2m_is_on_this_model):
191        """
192        @param m2m_field: a Django field representing the M2M relationship.
193                It uses a pivot table with the following structure:
194                this model table <---> M2M pivot table <---> joined model table
195        @param join_to_query: the query over the related model that we're
196                joining to.
197        @param alias: alias of joined table
198        """
199        if m2m_is_on_this_model:
200            # referenced field on this model
201            lhs_id_field = self.model._meta.pk
202            # foreign key on the pivot table referencing lhs_id_field
203            m2m_lhs_column = m2m_field.m2m_column_name()
204            # foreign key on the pivot table referencing rhd_id_field
205            m2m_rhs_column = m2m_field.m2m_reverse_name()
206            # referenced field on related model
207            rhs_id_field = m2m_field.rel.get_related_field()
208        else:
209            lhs_id_field = m2m_field.rel.get_related_field()
210            m2m_lhs_column = m2m_field.m2m_reverse_name()
211            m2m_rhs_column = m2m_field.m2m_column_name()
212            rhs_id_field = join_to_query.model._meta.pk
213
214        info = {}
215        info['rhs_table'] = m2m_field.m2m_db_table()
216        info['rhs_column'] = m2m_lhs_column
217        info['lhs_column'] = lhs_id_field.column
218
219        # select the ID of related models relevant to this join.  we can only do
220        # a single join, so we need to gather this information up front and
221        # include it in the join condition.
222        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
223        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
224                                   'match a single related object.')
225        rhs_id = rhs_ids[0]
226
227        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
228                                               _quote_name(m2m_rhs_column),
229                                               rhs_id)
230        info['values'] = ()
231        return info
232
233
234    def join_custom_field(self, query_set, join_to_query, alias,
235                          left_join=True):
236        """Join to a related model to create a custom field in the given query.
237
238        This method is used to construct a custom field on the given query based
239        on a many-valued relationsip.  join_to_query should be a simple query
240        (no joins) on the related model which returns at most one related row
241        per instance of this model.
242
243        For many-to-one relationships, the joined table contains the matching
244        row from the related model it one is related, NULL otherwise.
245
246        For many-to-many relationships, the joined table contains the matching
247        row if it's related, NULL otherwise.
248        """
249        relationship_type, field = self.determine_relationship(
250                join_to_query.model)
251
252        if relationship_type == self.MANY_TO_ONE:
253            info = self._info_for_many_to_one_join(field, join_to_query, alias)
254        elif relationship_type == self.M2M_ON_RELATED_MODEL:
255            info = self._info_for_many_to_many_join(
256                    m2m_field=field, join_to_query=join_to_query, alias=alias,
257                    m2m_is_on_this_model=False)
258        elif relationship_type ==self.M2M_ON_THIS_MODEL:
259            info = self._info_for_many_to_many_join(
260                    m2m_field=field, join_to_query=join_to_query, alias=alias,
261                    m2m_is_on_this_model=True)
262
263        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
264                             join_from_key=info['lhs_column'],
265                             join_condition=info['where_clause'],
266                             join_condition_values=info['values'],
267                             alias=alias,
268                             force_left_join=left_join)
269
270
271    def add_where(self, query_set, where, values=()):
272        """Adds a where clause to the query_set."""
273        query_set = query_set.all()
274        query_set.query.where.add(self._WhereClause(where, values),
275                                  django.db.models.sql.where.AND)
276        return query_set
277
278
279    def _get_quoted_field(self, table, field):
280        return _quote_name(table) + '.' + _quote_name(field)
281
282
283    def get_key_on_this_table(self, key_field=None):
284        if key_field is None:
285            # default to primary key
286            key_field = self.model._meta.pk.column
287        return self._get_quoted_field(self.model._meta.db_table, key_field)
288
289
290    def escape_user_sql(self, sql):
291        """Escapes % in sql."""
292        return sql.replace('%', '%%')
293
294
295    def _custom_select_query(self, query_set, selects):
296        """Execute a custom select query.
297
298        @param query_set: query set as returned by query_objects.
299        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.
300
301        @returns: Result of the query as returned by cursor.fetchall().
302        """
303        compiler = query_set.query.get_compiler(using=query_set.db)
304        sql, params = compiler.as_sql()
305        from_ = sql[sql.find(' FROM'):]
306
307        if query_set.query.distinct:
308            distinct = 'DISTINCT '
309        else:
310            distinct = ''
311
312        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
313        # Chose the connection that's responsible for this type of object
314        cursor = connections[query_set.db].cursor()
315        cursor.execute(sql_query, params)
316        return cursor.fetchall()
317
318
319    def _is_relation_to(self, field, model_class):
320        return field.rel and field.rel.to is model_class
321
322
    # Relationship kinds reported by determine_relationship().  Each one is
    # a unique sentinel object.
    # The related model has a foreign key pointing at this model.
    MANY_TO_ONE = object()
    # The related model declares a many-to-many field to this model.
    M2M_ON_RELATED_MODEL = object()
    # This model declares a many-to-many field to the related model.
    M2M_ON_THIS_MODEL = object()
326
327    def determine_relationship(self, related_model):
328        """
329        Determine the relationship between this model and related_model.
330
331        related_model must have some sort of many-valued relationship to this
332        manager's model.
333        @returns (relationship_type, field), where relationship_type is one of
334                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
335                is the Django field object for the relationship.
336        """
337        # look for a foreign key field on related_model relating to this model
338        for field in related_model._meta.fields:
339            if self._is_relation_to(field, self.model):
340                return self.MANY_TO_ONE, field
341
342        # look for an M2M field on related_model relating to this model
343        for field in related_model._meta.many_to_many:
344            if self._is_relation_to(field, self.model):
345                return self.M2M_ON_RELATED_MODEL, field
346
347        # maybe this model has the many-to-many field
348        for field in self.model._meta.many_to_many:
349            if self._is_relation_to(field, related_model):
350                return self.M2M_ON_THIS_MODEL, field
351
352        raise ValueError('%s has no relation to %s' %
353                         (related_model, self.model))
354
355
356    def _get_pivot_iterator(self, base_objects_by_id, related_model):
357        """
358        Determine the relationship between this model and related_model, and
359        return a pivot iterator.
360        @param base_objects_by_id: dict of instances of this model indexed by
361        their IDs
362        @returns a pivot iterator, which yields a tuple (base_object,
363        related_object) for each relationship between a base object and a
364        related object.  all base_object instances come from base_objects_by_id.
365        Note -- this depends on Django model internals.
366        """
367        relationship_type, field = self.determine_relationship(related_model)
368        if relationship_type == self.MANY_TO_ONE:
369            return self._many_to_one_pivot(base_objects_by_id,
370                                           related_model, field)
371        elif relationship_type == self.M2M_ON_RELATED_MODEL:
372            return self._many_to_many_pivot(
373                    base_objects_by_id, related_model, field.m2m_db_table(),
374                    field.m2m_reverse_name(), field.m2m_column_name())
375        else:
376            assert relationship_type == self.M2M_ON_THIS_MODEL
377            return self._many_to_many_pivot(
378                    base_objects_by_id, related_model, field.m2m_db_table(),
379                    field.m2m_column_name(), field.m2m_reverse_name())
380
381
382    def _many_to_one_pivot(self, base_objects_by_id, related_model,
383                           foreign_key_field):
384        """
385        @returns a pivot iterator - see _get_pivot_iterator()
386        """
387        filter_data = {foreign_key_field.name + '__pk__in':
388                       base_objects_by_id.keys()}
389        for related_object in related_model.objects.filter(**filter_data):
390            # lookup base object in the dict, rather than grabbing it from the
391            # related object.  we need to return instances from the dict, not
392            # fresh instances of the same models (and grabbing model instances
393            # from the related models incurs a DB query each time).
394            base_object_id = getattr(related_object, foreign_key_field.attname)
395            base_object = base_objects_by_id[base_object_id]
396            yield base_object, related_object
397
398
399    def _query_pivot_table(self, base_objects_by_id, pivot_table,
400                           pivot_from_field, pivot_to_field, related_model):
401        """
402        @param id_list list of IDs of self.model objects to include
403        @param pivot_table the name of the pivot table
404        @param pivot_from_field a field name on pivot_table referencing
405        self.model
406        @param pivot_to_field a field name on pivot_table referencing the
407        related model.
408        @param related_model the related model
409
410        @returns pivot list of IDs (base_id, related_id)
411        """
412        query = """
413        SELECT %(from_field)s, %(to_field)s
414        FROM %(table)s
415        WHERE %(from_field)s IN (%(id_list)s)
416        """ % dict(from_field=pivot_from_field,
417                   to_field=pivot_to_field,
418                   table=pivot_table,
419                   id_list=','.join(
420                           str(id_)
421                           for id_ in six.iterkeys(base_objects_by_id)))
422
423        # Chose the connection that's responsible for this type of object
424        # The databases for related_model and the current model will always
425        # be the same, related_model is just easier to obtain here because
426        # self is only a ExtendedManager, not the object.
427        cursor = connections[related_model.objects.db].cursor()
428        cursor.execute(query)
429        return cursor.fetchall()
430
431
432    def _many_to_many_pivot(self, base_objects_by_id, related_model,
433                            pivot_table, pivot_from_field, pivot_to_field):
434        """
435        @param pivot_table: see _query_pivot_table
436        @param pivot_from_field: see _query_pivot_table
437        @param pivot_to_field: see _query_pivot_table
438        @returns a pivot iterator - see _get_pivot_iterator()
439        """
440        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
441                                           pivot_from_field, pivot_to_field,
442                                           related_model)
443
444        all_related_ids = list(set(related_id for base_id, related_id
445                                   in id_pivot))
446        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)
447
448        for base_id, related_id in id_pivot:
449            yield base_objects_by_id[base_id], related_objects_by_id[related_id]
450
451
452    def populate_relationships(self, base_objects, related_model,
453                               related_list_name):
454        """
455        For each instance of this model in base_objects, add a field named
456        related_list_name listing all the related objects of type related_model.
457        related_model must be in a many-to-one or many-to-many relationship with
458        this model.
459        @param base_objects - list of instances of this model
460        @param related_model - model class related to this model
461        @param related_list_name - attribute name in which to store the related
462        object list.
463        """
464        if not base_objects:
465            # if we don't bail early, we'll get a SQL error later
466            return
467
468        # The default maximum value of a host parameter number in SQLite is 999.
469        # Exceed this will get a DatabaseError later.
470        batch_size = 900
471        for i in range(0, len(base_objects), batch_size):
472            base_objects_batch = base_objects[i:i + batch_size]
473            base_objects_by_id = dict((base_object._get_pk_val(), base_object)
474                                      for base_object in base_objects_batch)
475            pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
476                                                      related_model)
477
478            for base_object in base_objects_batch:
479                setattr(base_object, related_list_name, [])
480
481            for base_object, related_object in pivot_iterator:
482                getattr(base_object, related_list_name).append(related_object)
483
484
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet that handles delete() properly for models with an "invalid" bit
    """
    def delete(self):
        """Delete the matched objects one by one via their own delete()."""
        for instance in self:
            instance.delete()
493
494
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit
    """
    def get_query_set(self):
        """Return a ModelWithInvalidQuerySet over this manager's model."""
        query_set = ModelWithInvalidQuerySet(self.model)
        return query_set
501
502
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager returning only objects with invalid=False.
    """
    def get_query_set(self):
        """Return the parent query set narrowed to invalid=False."""
        return super(ValidObjectsManager, self).get_query_set().filter(
                invalid=False)
510
511
512class ModelExtensions(rdb_model_extensions.ModelValidators):
513    """\
514    Mixin with convenience functions for models, built on top of
515    the model validators in rdb_model_extensions.
516    """
517    # TODO: at least some of these functions really belong in a custom
518    # Manager class
519
520
521    SERIALIZATION_LINKS_TO_FOLLOW = set()
522    """
523    To be able to send jobs and hosts to shards, it's necessary to find their
524    dependencies.
525    The most generic approach for this would be to traverse all relationships
526    to other objects recursively. This would list all objects that are related
527    in any way.
528    But this approach finds too many objects: If a host should be transferred,
529    all it's relationships would be traversed. This would find an acl group.
530    If then the acl group's relationships are traversed, the relationship
531    would be followed backwards and many other hosts would be found.
532
533    This mapping tells that algorithm which relations to follow explicitly.
534    """
535
536
537    SERIALIZATION_LINKS_TO_KEEP = set()
538    """This set stores foreign keys which we don't want to follow, but
539    still want to include in the serialized dictionary. For
540    example, we follow the relationship `Host.hostattribute_set`,
541    but we do not want to follow `HostAttributes.host_id` back to
542    to Host, which would otherwise lead to a circle. However, we still
543    like to serialize HostAttribute.`host_id`."""
544
545    SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
546    """
547    On deserializion, if the object to persist already exists, local fields
548    will only be updated, if their name is in this set.
549    """
550
551
552    @classmethod
553    def convert_human_readable_values(cls, data, to_human_readable=False):
554        """\
555        Performs conversions on user-supplied field data, to make it
556        easier for users to pass human-readable data.
557
558        For all fields that have choice sets, convert their values
559        from human-readable strings to enum values, if necessary.  This
560        allows users to pass strings instead of the corresponding
561        integer values.
562
563        For all foreign key fields, call smart_get with the supplied
564        data.  This allows the user to pass either an ID value or
565        the name of the object as a string.
566
567        If to_human_readable=True, perform the inverse - i.e. convert
568        numeric values to human readable values.
569
570        This method modifies data in-place.
571        """
572        field_dict = cls.get_field_dict()
573        for field_name in data:
574            if field_name not in field_dict or data[field_name] is None:
575                continue
576            field_obj = field_dict[field_name]
577            # convert enum values
578            if field_obj.choices:
579                for choice_data in field_obj.choices:
580                    # choice_data is (value, name)
581                    if to_human_readable:
582                        from_val, to_val = choice_data
583                    else:
584                        to_val, from_val = choice_data
585                    if from_val == data[field_name]:
586                        data[field_name] = to_val
587                        break
588            # convert foreign key values
589            elif field_obj.rel:
590                dest_obj = field_obj.rel.to.smart_get(data[field_name],
591                                                      valid_only=False)
592                if to_human_readable:
593                    # parameterized_jobs do not have a name_field
594                    if (field_name != 'parameterized_job' and
595                        dest_obj.name_field is not None):
596                        data[field_name] = getattr(dest_obj,
597                                                   dest_obj.name_field)
598                else:
599                    data[field_name] = dest_obj
600
601
602
603
604    def _validate_unique(self):
605        """\
606        Validate that unique fields are unique.  Django manipulators do
607        this too, but they're a huge pain to use manually.  Trust me.
608        """
609        errors = {}
610        cls = type(self)
611        field_dict = self.get_field_dict()
612        manager = cls.get_valid_manager()
613        for field_name, field_obj in six.iteritems(field_dict):
614            if not field_obj.unique:
615                continue
616
617            value = getattr(self, field_name)
618            if value is None and field_obj.auto_created:
619                # don't bother checking autoincrement fields about to be
620                # generated
621                continue
622
623            existing_objs = manager.filter(**{field_name : value})
624            num_existing = existing_objs.count()
625
626            if num_existing == 0:
627                continue
628            if num_existing == 1 and existing_objs[0].id == self.id:
629                continue
630            errors[field_name] = (
631                'This value must be unique (%s)' % (value))
632        return errors
633
634
635    def _validate(self):
636        """
637        First coerces all fields on this instance to their proper Python types.
638        Then runs validation on every field. Returns a dictionary of
639        field_name -> error_list.
640
641        Based on validate() from django.db.models.Model in Django 0.96, which
642        was removed in Django 1.0. It should reappear in a later version. See:
643            http://code.djangoproject.com/ticket/6845
644        """
645        error_dict = {}
646        for f in self._meta.fields:
647            try:
648                python_value = f.to_python(
649                    getattr(self, f.attname, f.get_default()))
650            except django.core.exceptions.ValidationError as e:
651                error_dict[f.name] = str(e)
652                continue
653
654            if not f.blank and not python_value:
655                error_dict[f.name] = 'This field is required.'
656                continue
657
658            setattr(self, f.attname, python_value)
659
660        return error_dict
661
662
663    def do_validate(self):
664        """Validate fields."""
665        errors = self._validate()
666        unique_errors = self._validate_unique()
667        for field_name, error in six.iteritems(unique_errors):
668            errors.setdefault(field_name, error)
669        if errors:
670            raise ValidationError(errors)
671
672
673    # actually (externally) useful methods follow
674
675    @classmethod
676    def add_object(cls, data={}, **kwargs):
677        """\
678        Returns a new object created with the given data (a dictionary
679        mapping field names to values). Merges any extra keyword args
680        into data.
681        """
682        data = dict(data)
683        data.update(kwargs)
684        data = cls.prepare_data_args(data)
685        cls.convert_human_readable_values(data)
686        data = cls.provide_default_values(data)
687
688        obj = cls(**data)
689        obj.do_validate()
690        obj.save()
691        return obj
692
693
694    def update_object(self, data={}, **kwargs):
695        """\
696        Updates the object with the given data (a dictionary mapping
697        field names to values).  Merges any extra keyword args into
698        data.
699        """
700        data = dict(data)
701        data.update(kwargs)
702        data = self.prepare_data_args(data)
703        self.convert_human_readable_values(data)
704        for field_name, value in six.iteritems(data):
705            setattr(self, field_name, value)
706        self.do_validate()
707        self.save()
708
709
    # Keys that _extract_special_params() removes from filter_data before
    # the remaining entries are handed to Django's filter().  See
    # query_objects() for what each key means.
    _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
                            'extra_args', 'extra_where', 'no_distinct')
713
714
715    @classmethod
716    def _extract_special_params(cls, filter_data):
717        """
718        @returns a tuple of dicts (special_params, regular_filters), where
719        special_params contains the parameters we handle specially and
720        regular_filters is the remaining data to be handled by Django.
721        """
722        regular_filters = dict(filter_data)
723        special_params = {}
724        for key in cls._SPECIAL_FILTER_KEYS:
725            if key in regular_filters:
726                special_params[key] = regular_filters.pop(key)
727        return special_params, regular_filters
728
729
730    @classmethod
731    def apply_presentation(cls, query, filter_data):
732        """
733        Apply presentation parameters -- sorting and paging -- to the given
734        query.
735        @returns new query with presentation applied
736        """
737        special_params, _ = cls._extract_special_params(filter_data)
738        sort_by = special_params.get('sort_by', None)
739        if sort_by:
740            assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
741            query = query.extra(order_by=sort_by)
742
743        query_start = special_params.get('query_start', None)
744        query_limit = special_params.get('query_limit', None)
745        if query_start is not None:
746            if query_limit is None:
747                raise ValueError('Cannot pass query_start without query_limit')
748            # query_limit is passed as a page size
749            query_limit += query_start
750        return query[query_start:query_limit]
751
752
753    @classmethod
754    def query_objects(cls, filter_data, valid_only=True, initial_query=None,
755                      apply_presentation=True):
756        """\
757        Returns a QuerySet object for querying the given model_class
758        with the given filter_data.  Optional special arguments in
759        filter_data include:
760        -query_start: index of first return to return
761        -query_limit: maximum number of results to return
762        -sort_by: list of fields to sort on.  prefixing a '-' onto a
763         field name changes the sort to descending order.
764        -extra_args: keyword args to pass to query.extra() (see Django
765         DB layer documentation)
766        -extra_where: extra WHERE clause to append
767        -no_distinct: if True, a DISTINCT will not be added to the SELECT
768        """
769        special_params, regular_filters = cls._extract_special_params(
770                filter_data)
771
772        if initial_query is None:
773            if valid_only:
774                initial_query = cls.get_valid_manager()
775            else:
776                initial_query = cls.objects
777
778        query = initial_query.filter(**regular_filters)
779
780        use_distinct = not special_params.get('no_distinct', False)
781        if use_distinct:
782            query = query.distinct()
783
784        extra_args = special_params.get('extra_args', {})
785        extra_where = special_params.get('extra_where', None)
786        if extra_where:
787            # escape %'s
788            extra_where = cls.objects.escape_user_sql(extra_where)
789            extra_args.setdefault('where', []).append(extra_where)
790        if extra_args:
791            query = query.extra(**extra_args)
792            # TODO: Use readonly connection for these queries.
793            # This has been disabled, because it's not used anyway, as the
794            # configured readonly user is the same as the real user anyway.
795
796        if apply_presentation:
797            query = cls.apply_presentation(query, filter_data)
798
799        return query
800
801
802    @classmethod
803    def query_count(cls, filter_data, initial_query=None):
804        """\
805        Like query_objects, but retreive only the count of results.
806        """
807        filter_data.pop('query_start', None)
808        filter_data.pop('query_limit', None)
809        query = cls.query_objects(filter_data, initial_query=initial_query)
810        return query.count()
811
812
813    @classmethod
814    def clean_object_dicts(cls, field_dicts):
815        """\
816        Take a list of dicts corresponding to object (as returned by
817        query.values()) and clean the data to be more suitable for
818        returning to the user.
819        """
820        for field_dict in field_dicts:
821            cls.clean_foreign_keys(field_dict)
822            cls._convert_booleans(field_dict)
823            cls.convert_human_readable_values(field_dict,
824                                              to_human_readable=True)
825
826
827    @classmethod
828    def list_objects(cls, filter_data, initial_query=None):
829        """\
830        Like query_objects, but return a list of dictionaries.
831        """
832        query = cls.query_objects(filter_data, initial_query=initial_query)
833        extra_fields = query.query.extra_select.keys()
834        field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
835                       for model_object in query]
836        return field_dicts
837
838
839    @classmethod
840    def smart_get(cls, id_or_name, valid_only=True):
841        """\
842        smart_get(integer) -> get object by ID
843        smart_get(string) -> get object by name_field
844        """
845        if valid_only:
846            manager = cls.get_valid_manager()
847        else:
848            manager = cls.objects
849
850        if isinstance(id_or_name, six.integer_types):
851            return manager.get(pk=id_or_name)
852        if isinstance(id_or_name, six.string_types) and hasattr(
853                cls, 'name_field'):
854            return manager.get(**{cls.name_field : id_or_name})
855        raise ValueError(
856            'Invalid positional argument: %s (%s)' % (id_or_name,
857                                                      type(id_or_name)))
858
859
860    @classmethod
861    def smart_get_bulk(cls, id_or_name_list):
862        """Like smart_get, but for a list of ids or names"""
863        invalid_inputs = []
864        result_objects = []
865        for id_or_name in id_or_name_list:
866            try:
867                result_objects.append(cls.smart_get(id_or_name))
868            except cls.DoesNotExist:
869                invalid_inputs.append(id_or_name)
870        if invalid_inputs:
871            raise cls.DoesNotExist('The following %ss do not exist: %s'
872                                   % (cls.__name__.lower(),
873                                      ', '.join(invalid_inputs)))
874        return result_objects
875
876
877    def get_object_dict(self, extra_fields=None):
878        """\
879        Return a dictionary mapping fields to this object's values.  @param
880        extra_fields: list of extra attribute names to include, in addition to
881        the fields defined on this object.
882        """
883        fields = self.get_field_dict().keys()
884        if extra_fields:
885            fields += extra_fields
886        object_dict = dict((field_name, getattr(self, field_name))
887                           for field_name in fields)
888        self.clean_object_dicts([object_dict])
889        self._postprocess_object_dict(object_dict)
890        return object_dict
891
892
    def _postprocess_object_dict(self, object_dict):
        """For subclasses to override.

        Called by get_object_dict() with the dict it is about to return.
        This base implementation does nothing.

        @param object_dict: dict mapping field names to this object's values.
        """
        pass
896
897
    @classmethod
    def get_valid_manager(cls):
        """Return the manager that query_objects() and smart_get() use when
        valid_only is set.

        This base implementation returns cls.objects.
        """
        return cls.objects
901
902
903    def _record_attributes(self, attributes):
904        """
905        See on_attribute_changed.
906        """
907        assert not isinstance(attributes, six.string_types)
908        self._recorded_attributes = dict((attribute, getattr(self, attribute))
909                                         for attribute in attributes)
910
911
912    def _check_for_updated_attributes(self):
913        """
914        See on_attribute_changed.
915        """
916        for attribute, original_value in six.iteritems(
917                self._recorded_attributes):
918            new_value = getattr(self, attribute)
919            if original_value != new_value:
920                self.on_attribute_changed(attribute, original_value)
921        self._record_attributes(self._recorded_attributes.keys())
922
923
    def on_attribute_changed(self, attribute, old_value):
        """
        Called whenever an attribute is updated.  To be overridden.

        This base implementation does nothing.

        To use this method, you must:
        * call _record_attributes() from __init__() (after making the super
        call) with a list of attributes for which you want to be notified upon
        change.
        * call _check_for_updated_attributes() from save().

        @param attribute: name of the attribute whose value changed.
        @param old_value: the value recorded before the change.
        """
        pass
935
936
937    def serialize(self, include_dependencies=True):
938        """Serializes the object with dependencies.
939
940        The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies
941        this function will serialize with the object.
942
943        @param include_dependencies: Whether or not to follow relations to
944                                     objects this object depends on.
945                                     This parameter is used when uploading
946                                     jobs from a shard to the main, as the
947                                     main already has all the dependent
948                                     objects.
949
950        @returns: Dictionary representation of the object.
951        """
952        serialized = {}
953        for field in self._meta.concrete_model._meta.local_fields:
954            if field.rel is None:
955                serialized[field.name] = field._get_val_from_obj(self)
956            elif field.name in self.SERIALIZATION_LINKS_TO_KEEP:
957                # attname will contain "_id" suffix for foreign keys,
958                # e.g. HostAttribute.host will be serialized as 'host_id'.
959                # Use it for easy deserialization.
960                serialized[field.attname] = field._get_val_from_obj(self)
961
962        if include_dependencies:
963            for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
964                serialized[link] = self._serialize_relation(link)
965
966        return serialized
967
968
969    def _serialize_relation(self, link):
970        """Serializes dependent objects given the name of the relation.
971
972        @param link: Name of the relation to take objects from.
973
974        @returns For To-Many relationships a list of the serialized related
975            objects, for To-One relationships the serialized related object.
976        """
977        try:
978            attr = getattr(self, link)
979        except AttributeError:
980            # One-To-One relationships that point to None may raise this
981            return None
982
983        if attr is None:
984            return None
985        if hasattr(attr, 'all'):
986            return [obj.serialize() for obj in attr.all()]
987        return attr.serialize()
988
989
990    @classmethod
991    def _split_local_from_foreign_values(cls, data):
992        """This splits local from foreign values in a serialized object.
993
994        @param data: The serialized object.
995
996        @returns A tuple of two lists, both containing tuples in the form
997                 (link_name, link_value). The first list contains all links
998                 for local fields, the second one contains those for foreign
999                 fields/objects.
1000        """
1001        links_to_local_values, links_to_related_values = [], []
1002        for link, value in six.iteritems(data):
1003            if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
1004                # It's a foreign key
1005                links_to_related_values.append((link, value))
1006            else:
1007                # It's a local attribute or a foreign key
1008                # we don't want to follow.
1009                links_to_local_values.append((link, value))
1010        return links_to_local_values, links_to_related_values
1011
1012
1013    @classmethod
1014    def _filter_update_allowed_fields(cls, data):
1015        """Filters data and returns only files that updates are allowed on.
1016
1017        This is i.e. needed for syncing aborted bits from the main to shards.
1018
1019        Local links are only allowed to be updated, if they are in
1020        SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1021        Overwriting existing values is allowed in order to be able to sync i.e.
1022        the aborted bit from the main to a shard.
1023
1024        The allowlisting mechanism is in place to prevent overwriting local
1025        status: If all fields were overwritten, jobs would be completely be
1026        set back to their original (unstarted) state.
1027
1028        @param data: List with tuples of the form (link_name, link_value), as
1029                     returned by _split_local_from_foreign_values.
1030
1031        @returns List of the same format as data, but only containing data for
1032                 fields that updates are allowed on.
1033        """
1034        return [pair for pair in data
1035                if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1036
1037
1038    @classmethod
1039    def delete_matching_record(cls, **filter_args):
1040        """Delete records matching the filter.
1041
1042        @param filter_args: Arguments for the django filter
1043                used to locate the record to delete.
1044        """
1045        try:
1046            existing_record = cls.objects.get(**filter_args)
1047        except cls.DoesNotExist:
1048            return
1049        existing_record.delete()
1050
1051
1052    def _deserialize_local(self, data):
1053        """Set local attributes from a list of tuples.
1054
1055        @param data: List of tuples like returned by
1056                     _split_local_from_foreign_values.
1057        """
1058        if not data:
1059            return
1060
1061        for link, value in data:
1062            setattr(self, link, value)
1063        # Overwridden save() methods are prone to errors, so don't execute them.
1064        # This is because:
1065        # - the overwritten methods depend on ACL groups that don't yet exist
1066        #   and don't handle errors
1067        # - the overwritten methods think this object already exists in the db
1068        #   because the id is already set
1069        super(type(self), self).save()
1070
1071
1072    def _deserialize_relations(self, data):
1073        """Set foreign attributes from a list of tuples.
1074
1075        This deserialized the related objects using their own deserialize()
1076        function and then sets the relation.
1077
1078        @param data: List of tuples like returned by
1079                     _split_local_from_foreign_values.
1080        """
1081        for link, value in data:
1082            self._deserialize_relation(link, value)
1083        # See comment in _deserialize_local
1084        super(type(self), self).save()
1085
1086
1087    @classmethod
1088    def get_record(cls, data):
1089        """Retrieve a record with the data in the given input arg.
1090
1091        @param data: A dictionary containing the information to use in a query
1092                for data. If child models have different constraints of
1093                uniqueness they should override this model.
1094
1095        @return: An object with matching data.
1096
1097        @raises DoesNotExist: If a record with the given data doesn't exist.
1098        """
1099        return cls.objects.get(id=data['id'])
1100
1101
1102    @classmethod
1103    def deserialize(cls, data):
1104        """Recursively deserializes and saves an object with it's dependencies.
1105
1106        This takes the result of the serialize method and creates objects
1107        in the database that are just like the original.
1108
1109        If an object of the same type with the same id already exists, it's
1110        local values will be left untouched, unless they are explicitly
1111        allowlisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1112
1113        Deserialize will always recursively propagate to all related objects
1114        present in data though.
1115        I.e. this is necessary to add users to an already existing acl-group.
1116
1117        @param data: Representation of an object and its dependencies, as
1118                     returned by serialize.
1119
1120        @returns: The object represented by data if it didn't exist before,
1121                  otherwise the object that existed before and has the same type
1122                  and id as the one described by data.
1123        """
1124        if data is None:
1125            return None
1126
1127        local, related = cls._split_local_from_foreign_values(data)
1128        try:
1129            instance = cls.get_record(data)
1130            local = cls._filter_update_allowed_fields(local)
1131        except cls.DoesNotExist:
1132            instance = cls()
1133
1134        instance._deserialize_local(local)
1135        instance._deserialize_relations(related)
1136
1137        return instance
1138
1139
1140    def _check_update_from_shard(self, shard, updated_serialized,
1141                                       *args, **kwargs):
1142        """Check if an update sent from a shard is legitimate.
1143
1144        @raises error.UnallowedRecordsSentToMain if an update is not
1145                legitimate.
1146        """
1147        raise NotImplementedError(
1148            '_check_update_from_shard must be implemented by subclass %s '
1149            'for type %s' % type(self))
1150
1151
1152    @transaction.commit_on_success
1153    def update_from_serialized(self, serialized):
1154        """Updates local fields of an existing object from a serialized form.
1155
1156        This is different than the normal deserialize() in the way that it
1157        does update local values, which deserialize doesn't, but doesn't
1158        recursively propagate to related objects, which deserialize() does.
1159
1160        The use case of this function is to update job records on the main
1161        after the jobs have been executed on a shard, as the main is not
1162        interested in updates for users, labels, specialtasks, etc.
1163
1164        @param serialized: Representation of an object and its dependencies, as
1165                           returned by serialize.
1166
1167        @raises ValueError: if serialized contains related objects, i.e. not
1168                            only local fields.
1169        """
1170        local, related = (
1171            self._split_local_from_foreign_values(serialized))
1172        if related:
1173            raise ValueError('Serialized must not contain foreign '
1174                             'objects: %s' % related)
1175
1176        self._deserialize_local(local)
1177
1178
1179    def custom_deserialize_relation(self, link, data):
1180        """Allows overriding the deserialization behaviour by subclasses."""
1181        raise NotImplementedError(
1182            'custom_deserialize_relation must be implemented by subclass %s '
1183            'for relation %s' % (type(self), link))
1184
1185
1186    def _deserialize_relation(self, link, data):
1187        """Deserializes related objects and sets references on this object.
1188
1189        Relations that point to a list of objects are handled automatically.
1190        For many-to-one or one-to-one relations custom_deserialize_relation
1191        must be overridden by the subclass.
1192
1193        Related objects are deserialized using their deserialize() method.
1194        Thereby they and their dependencies are created if they don't exist
1195        and saved to the database.
1196
1197        @param link: Name of the relation.
1198        @param data: Serialized representation of the related object(s).
1199                     This means a list of dictionaries for to-many relations,
1200                     just a dictionary for to-one relations.
1201        """
1202        field = getattr(self, link)
1203
1204        if field and hasattr(field, 'all'):
1205            self._deserialize_2m_relation(link, data, field.model)
1206        else:
1207            self.custom_deserialize_relation(link, data)
1208
1209
1210    def _deserialize_2m_relation(self, link, data, related_class):
1211        """Deserialize related objects for one to-many relationship.
1212
1213        @param link: Name of the relation.
1214        @param data: Serialized representation of the related objects.
1215                     This is a list with of dictionaries.
1216        @param related_class: A class representing a django model, with which
1217                              this class has a one-to-many relationship.
1218        """
1219        relation_set = getattr(self, link)
1220        if related_class == self.get_attribute_model():
1221            # When deserializing a model together with
1222            # its attributes, clear all the exising attributes to ensure
1223            # db consistency. Note 'update' won't be sufficient, as we also
1224            # want to remove any attributes that no longer exist in |data|.
1225            #
1226            # core_filters is a dictionary of filters, defines how
1227            # RelatedMangager would query for the 1-to-many relationship. E.g.
1228            # Host.objects.get(
1229            #     id=20).hostattribute_set.core_filters = {host_id:20}
1230            # We use it to delete objects related to the current object.
1231            related_class.objects.filter(**relation_set.core_filters).delete()
1232        for serialized in data:
1233            relation_set.add(related_class.deserialize(serialized))
1234
1235
    @classmethod
    def get_attribute_model(cls):
        """Return the attribute model.

        Subclass with attribute-like model should override this to
        return the attribute model class. This method will be
        called by _deserialize_2m_relation to determine whether
        to clear the one-to-many relations first on deserialization of object.

        @returns None in this base implementation.
        """
        return None
1246
1247
class ModelWithInvalid(ModelExtensions):
    """
    Model base that invalidates objects in place of actually deleting them.

    save() and delete() are overridden to support this.  Subclasses must have
    a boolean "invalid" field.
    """

    def save(self, *args, **kwargs):
        """Saves the model.

        When the object has no id yet and an invalidated object with the
        same name exists, that object is resurrected instead of a new one
        being added.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            lookup = {self.name_field: getattr(self, self.name_field),
                      'invalid': True}
            try:
                self.resurrect_object(self.__class__.objects.get(**lookup))
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is actually
        "undeleting" a previously deleted object.  Can be overridden by
        subclasses to copy data as desired from the deleted entry (but this
        superclass implementation must normally be called).

        @param old_object: the previously invalidated object; self takes
                over its id.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        This method is called when an object is marked invalid.
        Subclasses should override this to clean up relationships that
        should no longer exist if the object were deleted.
        """


    def delete(self):
        """Marks the model invalid instead of deleting it.

        The object must not already be invalid.  After saving the flag,
        clean_object() is called.
        """
        # NOTE(review): self-assignment kept from the original; its purpose
        # is not visible here -- confirm before removing.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Returns the manager for valid objects, cls.valid_objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
1315
1316
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.
        """
        raise NotImplementedError


    def _is_replaced_by_static_attribute(self, attribute):
        """
        Subclasses could override this to indicate whether the given
        attribute is static.  Static attributes cannot be set or deleted
        through this mixin.  This base implementation returns False.
        """
        return False


    def set_attribute(self, attribute, value):
        """Sets an attribute, creating its record if necessary.

        @param attribute: name of the attribute.
        @param value: value to store in the attribute record.

        @raises error.UnmodifiableAttributeException: if the attribute is
                static.
        """
        if self._is_replaced_by_static_attribute(attribute):
            raise error.UnmodifiableAttributeException(
                    'Failed to set attribute "%s" for host "%s" since it '
                    'is static. Use go/chromeos-skylab-inventory-tools to '
                    'modify this attribute.' % (attribute, self.hostname))

        model, lookup = self._get_attribute_model_and_args(attribute)
        record, _created = model.objects.get_or_create(**lookup)
        record.value = value
        record.save()


    def delete_attribute(self, attribute):
        """Deletes an attribute; does nothing if it does not exist.

        @param attribute: name of the attribute.

        @raises error.UnmodifiableAttributeException: if the attribute is
                static.
        """
        if self._is_replaced_by_static_attribute(attribute):
            raise error.UnmodifiableAttributeException(
                    'Failed to delete attribute "%s" for host "%s" since it '
                    'is static. Use go/chromeos-skylab-inventory-tools to '
                    'modify this attribute.' % (attribute, self.hostname))

        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            # Nothing stored under that name, so there is nothing to delete.
            pass


    def set_or_delete_attribute(self, attribute, value):
        """Sets the attribute, or deletes it when value is None.

        @param attribute: name of the attribute.
        @param value: new value, or None to delete the attribute.
        """
        if value is not None:
            self.set_attribute(attribute, value)
        else:
            self.delete_attribute(attribute)
1376
1377
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always raises exception; use get_or_create() instead."""
        raise Exception('ModelWithHash manager should use get_or_create() '
                        'instead of create()')


    def get_or_create(self, **kwargs):
        """Like the parent get_or_create(), with the_hash filled in.

        The hash is computed by the model's _compute_hash() from the given
        keyword arguments and added to them as 'the_hash'.
        """
        computed_hash = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = computed_hash
        return super(ModelWithHashManager, self).get_or_create(**kwargs)
1390
1391
class ModelWithHash(dbmodels.Model):
    """Superclass with methods for dealing with a hash column"""

    # Unique hash column; ModelWithHashManager.get_or_create() fills it in
    # from _compute_hash().
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        """Overrides dbmodels.Model.Meta."""
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given keyword arguments.

        Must be overridden; this base implementation always raises
        NotImplementedError.
        """
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work. These models should be instantiated through their
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row. However, the preferred way to make instances of
        these models is through the get_or_create() method.
        """
        if force_insert:
            # Allow a forced insert to happen; if it's a duplicate, the unique
            # constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')
1425