• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1"""
2Extensions to Django's model logic.
3"""
4
5import django.core.exceptions
6from django.db import connection
7from django.db import connections
8from django.db import models as dbmodels
9from django.db import transaction
10from django.db.models.sql import query
11import django.db.models.sql.where
12from autotest_lib.frontend.afe import rdb_model_extensions
13
14
class ValidationError(django.core.exceptions.ValidationError):
    """\
    Data validation error raised while adding or updating an object.

    The associated value is a dictionary mapping field names to error
    strings.
    """
20
def _quote_name(name):
    """Pass name through connection.ops.quote_name() and return the result."""
    quote = connection.ops.quote_name
    return quote(name)
24
25
class LeasedHostManager(dbmodels.Manager):
    """Manager whose query set only contains unleased, unlocked hosts."""
    def get_query_set(self):
        """Return the parent query set filtered on leased=0 and locked=0."""
        base_query_set = super(LeasedHostManager, self).get_query_set()
        return base_query_set.filter(leased=0, locked=0)
32
33
34class ExtendedManager(dbmodels.Manager):
35    """\
36    Extended manager supporting subquery filtering.
37    """
38
39    class CustomQuery(query.Query):
40        def __init__(self, *args, **kwargs):
41            super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs)
42            self._custom_joins = []
43
44
45        def clone(self, klass=None, **kwargs):
46            obj = super(ExtendedManager.CustomQuery, self).clone(klass)
47            obj._custom_joins = list(self._custom_joins)
48            return obj
49
50
51        def combine(self, rhs, connector):
52            super(ExtendedManager.CustomQuery, self).combine(rhs, connector)
53            if hasattr(rhs, '_custom_joins'):
54                self._custom_joins.extend(rhs._custom_joins)
55
56
57        def add_custom_join(self, table, condition, join_type,
58                            condition_values=(), alias=None):
59            if alias is None:
60                alias = table
61            join_dict = dict(table=table,
62                             condition=condition,
63                             condition_values=condition_values,
64                             join_type=join_type,
65                             alias=alias)
66            self._custom_joins.append(join_dict)
67
68
69        @classmethod
70        def convert_query(self, query_set):
71            """
72            Convert the query set's "query" attribute to a CustomQuery.
73            """
74            # Make a copy of the query set
75            query_set = query_set.all()
76            query_set.query = query_set.query.clone(
77                    klass=ExtendedManager.CustomQuery,
78                    _custom_joins=[])
79            return query_set
80
81
82    class _WhereClause(object):
83        """Object allowing us to inject arbitrary SQL into Django queries.
84
85        By using this instead of extra(where=...), we can still freely combine
86        queries with & and |.
87        """
88        def __init__(self, clause, values=()):
89            self._clause = clause
90            self._values = values
91
92
93        def as_sql(self, qn=None, connection=None):
94            return self._clause, self._values
95
96
97        def relabel_aliases(self, change_map):
98            return
99
100
101    def add_join(self, query_set, join_table, join_key, join_condition='',
102                 join_condition_values=(), join_from_key=None, alias=None,
103                 suffix='', exclude=False, force_left_join=False):
104        """Add a join to query_set.
105
106        Join looks like this:
107                (INNER|LEFT) JOIN <join_table> AS <alias>
108                    ON (<this table>.<join_from_key> = <join_table>.<join_key>
109                        and <join_condition>)
110
111        @param join_table table to join to
112        @param join_key field referencing back to this model to use for the join
113        @param join_condition extra condition for the ON clause of the join
114        @param join_condition_values values to substitute into join_condition
115        @param join_from_key column on this model to join from.
116        @param alias alias to use for for join
117        @param suffix suffix to add to join_table for the join alias, if no
118                alias is provided
119        @param exclude if true, exclude rows that match this join (will use a
120        LEFT OUTER JOIN and an appropriate WHERE condition)
121        @param force_left_join - if true, a LEFT OUTER JOIN will be used
122        instead of an INNER JOIN regardless of other options
123        """
124        join_from_table = query_set.model._meta.db_table
125        if join_from_key is None:
126            join_from_key = self.model._meta.pk.name
127        if alias is None:
128            alias = join_table + suffix
129        full_join_key = _quote_name(alias) + '.' + _quote_name(join_key)
130        full_join_condition = '%s = %s.%s' % (full_join_key,
131                                              _quote_name(join_from_table),
132                                              _quote_name(join_from_key))
133        if join_condition:
134            full_join_condition += ' AND (' + join_condition + ')'
135        if exclude or force_left_join:
136            join_type = query_set.query.LOUTER
137        else:
138            join_type = query_set.query.INNER
139
140        query_set = self.CustomQuery.convert_query(query_set)
141        query_set.query.add_custom_join(join_table,
142                                        full_join_condition,
143                                        join_type,
144                                        condition_values=join_condition_values,
145                                        alias=alias)
146
147        if exclude:
148            query_set = query_set.extra(where=[full_join_key + ' IS NULL'])
149
150        return query_set
151
152
    def _info_for_many_to_one_join(self, field, join_to_query, alias):
        """
        Build the join description for a many-to-one custom field join.

        @param field: the ForeignKey field on the related model
        @param join_to_query: the query over the related model that we're
                joining to
        @param alias: alias of joined table
        @returns dict with the keys 'rhs_table', 'rhs_column', 'lhs_column',
                'where_clause' and 'values', as read by join_custom_field().
        """
        info = {}
        rhs_table = join_to_query.model._meta.db_table
        info['rhs_table'] = rhs_table
        info['rhs_column'] = field.column
        info['lhs_column'] = field.rel.get_related_field().column
        rhs_where = join_to_query.query.where
        # Called for its effect on join_to_query's own where tree (the return
        # value is unused): references to the related table are switched to
        # the join alias before the query is compiled below.
        rhs_where.relabel_aliases({rhs_table: alias})
        compiler = join_to_query.query.get_compiler(using=join_to_query.db)
        initial_clause, values = compiler.as_sql()
        # initial_clause is compiled from `join_to_query`, which is a SELECT
        # query that returns at most one record. For it to be used in a WHERE
        # clause, it must be converted to a boolean value using EXISTS.
        all_clauses = ('EXISTS (%s)' % initial_clause,)
        # NOTE(review): 'extra_where' is presumably only present on some Django
        # versions, hence the hasattr() guard -- confirm.
        if hasattr(join_to_query.query, 'extra_where'):
            all_clauses += join_to_query.query.extra_where
        # parenthesize every clause and AND them all together
        info['where_clause'] = (
                    ' AND '.join('(%s)' % clause for clause in all_clauses))
        info['values'] = values
        return info
179
180
181    def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias,
182                                    m2m_is_on_this_model):
183        """
184        @param m2m_field: a Django field representing the M2M relationship.
185                It uses a pivot table with the following structure:
186                this model table <---> M2M pivot table <---> joined model table
187        @param join_to_query: the query over the related model that we're
188                joining to.
189        @param alias: alias of joined table
190        """
191        if m2m_is_on_this_model:
192            # referenced field on this model
193            lhs_id_field = self.model._meta.pk
194            # foreign key on the pivot table referencing lhs_id_field
195            m2m_lhs_column = m2m_field.m2m_column_name()
196            # foreign key on the pivot table referencing rhd_id_field
197            m2m_rhs_column = m2m_field.m2m_reverse_name()
198            # referenced field on related model
199            rhs_id_field = m2m_field.rel.get_related_field()
200        else:
201            lhs_id_field = m2m_field.rel.get_related_field()
202            m2m_lhs_column = m2m_field.m2m_reverse_name()
203            m2m_rhs_column = m2m_field.m2m_column_name()
204            rhs_id_field = join_to_query.model._meta.pk
205
206        info = {}
207        info['rhs_table'] = m2m_field.m2m_db_table()
208        info['rhs_column'] = m2m_lhs_column
209        info['lhs_column'] = lhs_id_field.column
210
211        # select the ID of related models relevant to this join.  we can only do
212        # a single join, so we need to gather this information up front and
213        # include it in the join condition.
214        rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True)
215        assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only '
216                                   'match a single related object.')
217        rhs_id = rhs_ids[0]
218
219        info['where_clause'] = '%s.%s = %s' % (_quote_name(alias),
220                                               _quote_name(m2m_rhs_column),
221                                               rhs_id)
222        info['values'] = ()
223        return info
224
225
226    def join_custom_field(self, query_set, join_to_query, alias,
227                          left_join=True):
228        """Join to a related model to create a custom field in the given query.
229
230        This method is used to construct a custom field on the given query based
231        on a many-valued relationsip.  join_to_query should be a simple query
232        (no joins) on the related model which returns at most one related row
233        per instance of this model.
234
235        For many-to-one relationships, the joined table contains the matching
236        row from the related model it one is related, NULL otherwise.
237
238        For many-to-many relationships, the joined table contains the matching
239        row if it's related, NULL otherwise.
240        """
241        relationship_type, field = self.determine_relationship(
242                join_to_query.model)
243
244        if relationship_type == self.MANY_TO_ONE:
245            info = self._info_for_many_to_one_join(field, join_to_query, alias)
246        elif relationship_type == self.M2M_ON_RELATED_MODEL:
247            info = self._info_for_many_to_many_join(
248                    m2m_field=field, join_to_query=join_to_query, alias=alias,
249                    m2m_is_on_this_model=False)
250        elif relationship_type ==self.M2M_ON_THIS_MODEL:
251            info = self._info_for_many_to_many_join(
252                    m2m_field=field, join_to_query=join_to_query, alias=alias,
253                    m2m_is_on_this_model=True)
254
255        return self.add_join(query_set, info['rhs_table'], info['rhs_column'],
256                             join_from_key=info['lhs_column'],
257                             join_condition=info['where_clause'],
258                             join_condition_values=info['values'],
259                             alias=alias,
260                             force_left_join=left_join)
261
262
263    def add_where(self, query_set, where, values=()):
264        query_set = query_set.all()
265        query_set.query.where.add(self._WhereClause(where, values),
266                                  django.db.models.sql.where.AND)
267        return query_set
268
269
270    def _get_quoted_field(self, table, field):
271        return _quote_name(table) + '.' + _quote_name(field)
272
273
274    def get_key_on_this_table(self, key_field=None):
275        if key_field is None:
276            # default to primary key
277            key_field = self.model._meta.pk.column
278        return self._get_quoted_field(self.model._meta.db_table, key_field)
279
280
281    def escape_user_sql(self, sql):
282        return sql.replace('%', '%%')
283
284
285    def _custom_select_query(self, query_set, selects):
286        """Execute a custom select query.
287
288        @param query_set: query set as returned by query_objects.
289        @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id.
290
291        @returns: Result of the query as returned by cursor.fetchall().
292        """
293        compiler = query_set.query.get_compiler(using=query_set.db)
294        sql, params = compiler.as_sql()
295        from_ = sql[sql.find(' FROM'):]
296
297        if query_set.query.distinct:
298            distinct = 'DISTINCT '
299        else:
300            distinct = ''
301
302        sql_query = ('SELECT ' + distinct + ','.join(selects) + from_)
303        # Chose the connection that's responsible for this type of object
304        cursor = connections[query_set.db].cursor()
305        cursor.execute(sql_query, params)
306        return cursor.fetchall()
307
308
309    def _is_relation_to(self, field, model_class):
310        return field.rel and field.rel.to is model_class
311
312
    # Relationship type markers returned by determine_relationship(); each is
    # a distinct object() instance.
    # MANY_TO_ONE: the related model has a field relating to this model.
    MANY_TO_ONE = object()
    # M2M_ON_RELATED_MODEL: the related model has the many-to-many field.
    M2M_ON_RELATED_MODEL = object()
    # M2M_ON_THIS_MODEL: this model has the many-to-many field.
    M2M_ON_THIS_MODEL = object()
316
317    def determine_relationship(self, related_model):
318        """
319        Determine the relationship between this model and related_model.
320
321        related_model must have some sort of many-valued relationship to this
322        manager's model.
323        @returns (relationship_type, field), where relationship_type is one of
324                MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field
325                is the Django field object for the relationship.
326        """
327        # look for a foreign key field on related_model relating to this model
328        for field in related_model._meta.fields:
329            if self._is_relation_to(field, self.model):
330                return self.MANY_TO_ONE, field
331
332        # look for an M2M field on related_model relating to this model
333        for field in related_model._meta.many_to_many:
334            if self._is_relation_to(field, self.model):
335                return self.M2M_ON_RELATED_MODEL, field
336
337        # maybe this model has the many-to-many field
338        for field in self.model._meta.many_to_many:
339            if self._is_relation_to(field, related_model):
340                return self.M2M_ON_THIS_MODEL, field
341
342        raise ValueError('%s has no relation to %s' %
343                         (related_model, self.model))
344
345
346    def _get_pivot_iterator(self, base_objects_by_id, related_model):
347        """
348        Determine the relationship between this model and related_model, and
349        return a pivot iterator.
350        @param base_objects_by_id: dict of instances of this model indexed by
351        their IDs
352        @returns a pivot iterator, which yields a tuple (base_object,
353        related_object) for each relationship between a base object and a
354        related object.  all base_object instances come from base_objects_by_id.
355        Note -- this depends on Django model internals.
356        """
357        relationship_type, field = self.determine_relationship(related_model)
358        if relationship_type == self.MANY_TO_ONE:
359            return self._many_to_one_pivot(base_objects_by_id,
360                                           related_model, field)
361        elif relationship_type == self.M2M_ON_RELATED_MODEL:
362            return self._many_to_many_pivot(
363                    base_objects_by_id, related_model, field.m2m_db_table(),
364                    field.m2m_reverse_name(), field.m2m_column_name())
365        else:
366            assert relationship_type == self.M2M_ON_THIS_MODEL
367            return self._many_to_many_pivot(
368                    base_objects_by_id, related_model, field.m2m_db_table(),
369                    field.m2m_column_name(), field.m2m_reverse_name())
370
371
372    def _many_to_one_pivot(self, base_objects_by_id, related_model,
373                           foreign_key_field):
374        """
375        @returns a pivot iterator - see _get_pivot_iterator()
376        """
377        filter_data = {foreign_key_field.name + '__pk__in':
378                       base_objects_by_id.keys()}
379        for related_object in related_model.objects.filter(**filter_data):
380            # lookup base object in the dict, rather than grabbing it from the
381            # related object.  we need to return instances from the dict, not
382            # fresh instances of the same models (and grabbing model instances
383            # from the related models incurs a DB query each time).
384            base_object_id = getattr(related_object, foreign_key_field.attname)
385            base_object = base_objects_by_id[base_object_id]
386            yield base_object, related_object
387
388
389    def _query_pivot_table(self, base_objects_by_id, pivot_table,
390                           pivot_from_field, pivot_to_field, related_model):
391        """
392        @param id_list list of IDs of self.model objects to include
393        @param pivot_table the name of the pivot table
394        @param pivot_from_field a field name on pivot_table referencing
395        self.model
396        @param pivot_to_field a field name on pivot_table referencing the
397        related model.
398        @param related_model the related model
399
400        @returns pivot list of IDs (base_id, related_id)
401        """
402        query = """
403        SELECT %(from_field)s, %(to_field)s
404        FROM %(table)s
405        WHERE %(from_field)s IN (%(id_list)s)
406        """ % dict(from_field=pivot_from_field,
407                   to_field=pivot_to_field,
408                   table=pivot_table,
409                   id_list=','.join(str(id_) for id_
410                                    in base_objects_by_id.iterkeys()))
411
412        # Chose the connection that's responsible for this type of object
413        # The databases for related_model and the current model will always
414        # be the same, related_model is just easier to obtain here because
415        # self is only a ExtendedManager, not the object.
416        cursor = connections[related_model.objects.db].cursor()
417        cursor.execute(query)
418        return cursor.fetchall()
419
420
421    def _many_to_many_pivot(self, base_objects_by_id, related_model,
422                            pivot_table, pivot_from_field, pivot_to_field):
423        """
424        @param pivot_table: see _query_pivot_table
425        @param pivot_from_field: see _query_pivot_table
426        @param pivot_to_field: see _query_pivot_table
427        @returns a pivot iterator - see _get_pivot_iterator()
428        """
429        id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table,
430                                           pivot_from_field, pivot_to_field,
431                                           related_model)
432
433        all_related_ids = list(set(related_id for base_id, related_id
434                                   in id_pivot))
435        related_objects_by_id = related_model.objects.in_bulk(all_related_ids)
436
437        for base_id, related_id in id_pivot:
438            yield base_objects_by_id[base_id], related_objects_by_id[related_id]
439
440
441    def populate_relationships(self, base_objects, related_model,
442                               related_list_name):
443        """
444        For each instance of this model in base_objects, add a field named
445        related_list_name listing all the related objects of type related_model.
446        related_model must be in a many-to-one or many-to-many relationship with
447        this model.
448        @param base_objects - list of instances of this model
449        @param related_model - model class related to this model
450        @param related_list_name - attribute name in which to store the related
451        object list.
452        """
453        if not base_objects:
454            # if we don't bail early, we'll get a SQL error later
455            return
456
457        base_objects_by_id = dict((base_object._get_pk_val(), base_object)
458                                  for base_object in base_objects)
459        pivot_iterator = self._get_pivot_iterator(base_objects_by_id,
460                                                  related_model)
461
462        for base_object in base_objects:
463            setattr(base_object, related_list_name, [])
464
465        for base_object, related_object in pivot_iterator:
466            getattr(base_object, related_list_name).append(related_object)
467
468
class ModelWithInvalidQuerySet(dbmodels.query.QuerySet):
    """
    QuerySet whose delete() is suitable for models with an "invalid" bit
    """
    def delete(self):
        """Call delete() on each object of this query set, one by one."""
        for instance in self:
            instance.delete()
476
477
class ModelWithInvalidManager(ExtendedManager):
    """
    Manager for objects with an "invalid" bit
    """
    def get_query_set(self):
        """Return a ModelWithInvalidQuerySet over this manager's model."""
        query_set = ModelWithInvalidQuerySet(self.model)
        return query_set
484
485
class ValidObjectsManager(ModelWithInvalidManager):
    """
    Manager that only returns objects with invalid=False.
    """
    def get_query_set(self):
        """Return the parent query set filtered on invalid=False."""
        all_objects = super(ValidObjectsManager, self).get_query_set()
        return all_objects.filter(invalid=False)
493
494
495class ModelExtensions(rdb_model_extensions.ModelValidators):
496    """\
497    Mixin with convenience functions for models, built on top of
498    the model validators in rdb_model_extensions.
499    """
500    # TODO: at least some of these functions really belong in a custom
501    # Manager class
502
503
504    SERIALIZATION_LINKS_TO_FOLLOW = set()
505    """
506    To be able to send jobs and hosts to shards, it's necessary to find their
507    dependencies.
508    The most generic approach for this would be to traverse all relationships
509    to other objects recursively. This would list all objects that are related
510    in any way.
511    But this approach finds too many objects: If a host should be transferred,
512    all it's relationships would be traversed. This would find an acl group.
513    If then the acl group's relationships are traversed, the relationship
514    would be followed backwards and many other hosts would be found.
515
516    This mapping tells that algorithm which relations to follow explicitly.
517    """
518
519
520    SERIALIZATION_LINKS_TO_KEEP = set()
521    """This set stores foreign keys which we don't want to follow, but
522    still want to include in the serialized dictionary. For
523    example, we follow the relationship `Host.hostattribute_set`,
524    but we do not want to follow `HostAttributes.host_id` back to
525    to Host, which would otherwise lead to a circle. However, we still
526    like to serialize HostAttribute.`host_id`."""
527
528    SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set()
529    """
530    On deserializion, if the object to persist already exists, local fields
531    will only be updated, if their name is in this set.
532    """
533
534
535    @classmethod
536    def convert_human_readable_values(cls, data, to_human_readable=False):
537        """\
538        Performs conversions on user-supplied field data, to make it
539        easier for users to pass human-readable data.
540
541        For all fields that have choice sets, convert their values
542        from human-readable strings to enum values, if necessary.  This
543        allows users to pass strings instead of the corresponding
544        integer values.
545
546        For all foreign key fields, call smart_get with the supplied
547        data.  This allows the user to pass either an ID value or
548        the name of the object as a string.
549
550        If to_human_readable=True, perform the inverse - i.e. convert
551        numeric values to human readable values.
552
553        This method modifies data in-place.
554        """
555        field_dict = cls.get_field_dict()
556        for field_name in data:
557            if field_name not in field_dict or data[field_name] is None:
558                continue
559            field_obj = field_dict[field_name]
560            # convert enum values
561            if field_obj.choices:
562                for choice_data in field_obj.choices:
563                    # choice_data is (value, name)
564                    if to_human_readable:
565                        from_val, to_val = choice_data
566                    else:
567                        to_val, from_val = choice_data
568                    if from_val == data[field_name]:
569                        data[field_name] = to_val
570                        break
571            # convert foreign key values
572            elif field_obj.rel:
573                dest_obj = field_obj.rel.to.smart_get(data[field_name],
574                                                      valid_only=False)
575                if to_human_readable:
576                    # parameterized_jobs do not have a name_field
577                    if (field_name != 'parameterized_job' and
578                        dest_obj.name_field is not None):
579                        data[field_name] = getattr(dest_obj,
580                                                   dest_obj.name_field)
581                else:
582                    data[field_name] = dest_obj
583
584
585
586
587    def _validate_unique(self):
588        """\
589        Validate that unique fields are unique.  Django manipulators do
590        this too, but they're a huge pain to use manually.  Trust me.
591        """
592        errors = {}
593        cls = type(self)
594        field_dict = self.get_field_dict()
595        manager = cls.get_valid_manager()
596        for field_name, field_obj in field_dict.iteritems():
597            if not field_obj.unique:
598                continue
599
600            value = getattr(self, field_name)
601            if value is None and field_obj.auto_created:
602                # don't bother checking autoincrement fields about to be
603                # generated
604                continue
605
606            existing_objs = manager.filter(**{field_name : value})
607            num_existing = existing_objs.count()
608
609            if num_existing == 0:
610                continue
611            if num_existing == 1 and existing_objs[0].id == self.id:
612                continue
613            errors[field_name] = (
614                'This value must be unique (%s)' % (value))
615        return errors
616
617
618    def _validate(self):
619        """
620        First coerces all fields on this instance to their proper Python types.
621        Then runs validation on every field. Returns a dictionary of
622        field_name -> error_list.
623
624        Based on validate() from django.db.models.Model in Django 0.96, which
625        was removed in Django 1.0. It should reappear in a later version. See:
626            http://code.djangoproject.com/ticket/6845
627        """
628        error_dict = {}
629        for f in self._meta.fields:
630            try:
631                python_value = f.to_python(
632                    getattr(self, f.attname, f.get_default()))
633            except django.core.exceptions.ValidationError, e:
634                error_dict[f.name] = str(e)
635                continue
636
637            if not f.blank and not python_value:
638                error_dict[f.name] = 'This field is required.'
639                continue
640
641            setattr(self, f.attname, python_value)
642
643        return error_dict
644
645
646    def do_validate(self):
647        errors = self._validate()
648        unique_errors = self._validate_unique()
649        for field_name, error in unique_errors.iteritems():
650            errors.setdefault(field_name, error)
651        if errors:
652            raise ValidationError(errors)
653
654
655    # actually (externally) useful methods follow
656
657    @classmethod
658    def add_object(cls, data={}, **kwargs):
659        """\
660        Returns a new object created with the given data (a dictionary
661        mapping field names to values). Merges any extra keyword args
662        into data.
663        """
664        data = dict(data)
665        data.update(kwargs)
666        data = cls.prepare_data_args(data)
667        cls.convert_human_readable_values(data)
668        data = cls.provide_default_values(data)
669
670        obj = cls(**data)
671        obj.do_validate()
672        obj.save()
673        return obj
674
675
676    def update_object(self, data={}, **kwargs):
677        """\
678        Updates the object with the given data (a dictionary mapping
679        field names to values).  Merges any extra keyword args into
680        data.
681        """
682        data = dict(data)
683        data.update(kwargs)
684        data = self.prepare_data_args(data)
685        self.convert_human_readable_values(data)
686        for field_name, value in data.iteritems():
687            setattr(self, field_name, value)
688        self.do_validate()
689        self.save()
690
691
    # Keys of filter_data that are handled specially instead of being passed
    # to Django's filter(); their meaning is described in query_objects().
    # _extract_special_params() separates them from the regular filters.
    _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by',
                            'extra_args', 'extra_where', 'no_distinct')
695
696
697    @classmethod
698    def _extract_special_params(cls, filter_data):
699        """
700        @returns a tuple of dicts (special_params, regular_filters), where
701        special_params contains the parameters we handle specially and
702        regular_filters is the remaining data to be handled by Django.
703        """
704        regular_filters = dict(filter_data)
705        special_params = {}
706        for key in cls._SPECIAL_FILTER_KEYS:
707            if key in regular_filters:
708                special_params[key] = regular_filters.pop(key)
709        return special_params, regular_filters
710
711
712    @classmethod
713    def apply_presentation(cls, query, filter_data):
714        """
715        Apply presentation parameters -- sorting and paging -- to the given
716        query.
717        @returns new query with presentation applied
718        """
719        special_params, _ = cls._extract_special_params(filter_data)
720        sort_by = special_params.get('sort_by', None)
721        if sort_by:
722            assert isinstance(sort_by, list) or isinstance(sort_by, tuple)
723            query = query.extra(order_by=sort_by)
724
725        query_start = special_params.get('query_start', None)
726        query_limit = special_params.get('query_limit', None)
727        if query_start is not None:
728            if query_limit is None:
729                raise ValueError('Cannot pass query_start without query_limit')
730            # query_limit is passed as a page size
731            query_limit += query_start
732        return query[query_start:query_limit]
733
734
735    @classmethod
736    def query_objects(cls, filter_data, valid_only=True, initial_query=None,
737                      apply_presentation=True):
738        """\
739        Returns a QuerySet object for querying the given model_class
740        with the given filter_data.  Optional special arguments in
741        filter_data include:
742        -query_start: index of first return to return
743        -query_limit: maximum number of results to return
744        -sort_by: list of fields to sort on.  prefixing a '-' onto a
745         field name changes the sort to descending order.
746        -extra_args: keyword args to pass to query.extra() (see Django
747         DB layer documentation)
748        -extra_where: extra WHERE clause to append
749        -no_distinct: if True, a DISTINCT will not be added to the SELECT
750        """
751        special_params, regular_filters = cls._extract_special_params(
752                filter_data)
753
754        if initial_query is None:
755            if valid_only:
756                initial_query = cls.get_valid_manager()
757            else:
758                initial_query = cls.objects
759
760        query = initial_query.filter(**regular_filters)
761
762        use_distinct = not special_params.get('no_distinct', False)
763        if use_distinct:
764            query = query.distinct()
765
766        extra_args = special_params.get('extra_args', {})
767        extra_where = special_params.get('extra_where', None)
768        if extra_where:
769            # escape %'s
770            extra_where = cls.objects.escape_user_sql(extra_where)
771            extra_args.setdefault('where', []).append(extra_where)
772        if extra_args:
773            query = query.extra(**extra_args)
774            # TODO: Use readonly connection for these queries.
775            # This has been disabled, because it's not used anyway, as the
776            # configured readonly user is the same as the real user anyway.
777
778        if apply_presentation:
779            query = cls.apply_presentation(query, filter_data)
780
781        return query
782
783
784    @classmethod
785    def query_count(cls, filter_data, initial_query=None):
786        """\
787        Like query_objects, but retreive only the count of results.
788        """
789        filter_data.pop('query_start', None)
790        filter_data.pop('query_limit', None)
791        query = cls.query_objects(filter_data, initial_query=initial_query)
792        return query.count()
793
794
795    @classmethod
796    def clean_object_dicts(cls, field_dicts):
797        """\
798        Take a list of dicts corresponding to object (as returned by
799        query.values()) and clean the data to be more suitable for
800        returning to the user.
801        """
802        for field_dict in field_dicts:
803            cls.clean_foreign_keys(field_dict)
804            cls._convert_booleans(field_dict)
805            cls.convert_human_readable_values(field_dict,
806                                              to_human_readable=True)
807
808
809    @classmethod
810    def list_objects(cls, filter_data, initial_query=None):
811        """\
812        Like query_objects, but return a list of dictionaries.
813        """
814        query = cls.query_objects(filter_data, initial_query=initial_query)
815        extra_fields = query.query.extra_select.keys()
816        field_dicts = [model_object.get_object_dict(extra_fields=extra_fields)
817                       for model_object in query]
818        return field_dicts
819
820
821    @classmethod
822    def smart_get(cls, id_or_name, valid_only=True):
823        """\
824        smart_get(integer) -> get object by ID
825        smart_get(string) -> get object by name_field
826        """
827        if valid_only:
828            manager = cls.get_valid_manager()
829        else:
830            manager = cls.objects
831
832        if isinstance(id_or_name, (int, long)):
833            return manager.get(pk=id_or_name)
834        if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'):
835            return manager.get(**{cls.name_field : id_or_name})
836        raise ValueError(
837            'Invalid positional argument: %s (%s)' % (id_or_name,
838                                                      type(id_or_name)))
839
840
841    @classmethod
842    def smart_get_bulk(cls, id_or_name_list):
843        invalid_inputs = []
844        result_objects = []
845        for id_or_name in id_or_name_list:
846            try:
847                result_objects.append(cls.smart_get(id_or_name))
848            except cls.DoesNotExist:
849                invalid_inputs.append(id_or_name)
850        if invalid_inputs:
851            raise cls.DoesNotExist('The following %ss do not exist: %s'
852                                   % (cls.__name__.lower(),
853                                      ', '.join(invalid_inputs)))
854        return result_objects
855
856
857    def get_object_dict(self, extra_fields=None):
858        """\
859        Return a dictionary mapping fields to this object's values.  @param
860        extra_fields: list of extra attribute names to include, in addition to
861        the fields defined on this object.
862        """
863        fields = self.get_field_dict().keys()
864        if extra_fields:
865            fields += extra_fields
866        object_dict = dict((field_name, getattr(self, field_name))
867                           for field_name in fields)
868        self.clean_object_dicts([object_dict])
869        self._postprocess_object_dict(object_dict)
870        return object_dict
871
872
873    def _postprocess_object_dict(self, object_dict):
874        """For subclasses to override."""
875        pass
876
877
878    @classmethod
879    def get_valid_manager(cls):
880        return cls.objects
881
882
883    def _record_attributes(self, attributes):
884        """
885        See on_attribute_changed.
886        """
887        assert not isinstance(attributes, basestring)
888        self._recorded_attributes = dict((attribute, getattr(self, attribute))
889                                         for attribute in attributes)
890
891
892    def _check_for_updated_attributes(self):
893        """
894        See on_attribute_changed.
895        """
896        for attribute, original_value in self._recorded_attributes.iteritems():
897            new_value = getattr(self, attribute)
898            if original_value != new_value:
899                self.on_attribute_changed(attribute, original_value)
900        self._record_attributes(self._recorded_attributes.keys())
901
902
903    def on_attribute_changed(self, attribute, old_value):
904        """
905        Called whenever an attribute is updated.  To be overridden.
906
907        To use this method, you must:
908        * call _record_attributes() from __init__() (after making the super
909        call) with a list of attributes for which you want to be notified upon
910        change.
911        * call _check_for_updated_attributes() from save().
912        """
913        pass
914
915
916    def serialize(self, include_dependencies=True):
917        """Serializes the object with dependencies.
918
919        The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies
920        this function will serialize with the object.
921
922        @param include_dependencies: Whether or not to follow relations to
923                                     objects this object depends on.
924                                     This parameter is used when uploading
925                                     jobs from a shard to the master, as the
926                                     master already has all the dependent
927                                     objects.
928
929        @returns: Dictionary representation of the object.
930        """
931        serialized = {}
932        for field in self._meta.concrete_model._meta.local_fields:
933            if field.rel is None:
934                serialized[field.name] = field._get_val_from_obj(self)
935            elif field.name in self.SERIALIZATION_LINKS_TO_KEEP:
936                # attname will contain "_id" suffix for foreign keys,
937                # e.g. HostAttribute.host will be serialized as 'host_id'.
938                # Use it for easy deserialization.
939                serialized[field.attname] = field._get_val_from_obj(self)
940
941        if include_dependencies:
942            for link in self.SERIALIZATION_LINKS_TO_FOLLOW:
943                serialized[link] = self._serialize_relation(link)
944
945        return serialized
946
947
948    def _serialize_relation(self, link):
949        """Serializes dependent objects given the name of the relation.
950
951        @param link: Name of the relation to take objects from.
952
953        @returns For To-Many relationships a list of the serialized related
954            objects, for To-One relationships the serialized related object.
955        """
956        try:
957            attr = getattr(self, link)
958        except AttributeError:
959            # One-To-One relationships that point to None may raise this
960            return None
961
962        if attr is None:
963            return None
964        if hasattr(attr, 'all'):
965            return [obj.serialize() for obj in attr.all()]
966        return attr.serialize()
967
968
969    @classmethod
970    def _split_local_from_foreign_values(cls, data):
971        """This splits local from foreign values in a serialized object.
972
973        @param data: The serialized object.
974
975        @returns A tuple of two lists, both containing tuples in the form
976                 (link_name, link_value). The first list contains all links
977                 for local fields, the second one contains those for foreign
978                 fields/objects.
979        """
980        links_to_local_values, links_to_related_values = [], []
981        for link, value in data.iteritems():
982            if link in cls.SERIALIZATION_LINKS_TO_FOLLOW:
983                # It's a foreign key
984                links_to_related_values.append((link, value))
985            else:
986                # It's a local attribute or a foreign key
987                # we don't want to follow.
988                links_to_local_values.append((link, value))
989        return links_to_local_values, links_to_related_values
990
991
992    @classmethod
993    def _filter_update_allowed_fields(cls, data):
994        """Filters data and returns only files that updates are allowed on.
995
996        This is i.e. needed for syncing aborted bits from the master to shards.
997
998        Local links are only allowed to be updated, if they are in
999        SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1000        Overwriting existing values is allowed in order to be able to sync i.e.
1001        the aborted bit from the master to a shard.
1002
1003        The whitelisting mechanism is in place to prevent overwriting local
1004        status: If all fields were overwritten, jobs would be completely be
1005        set back to their original (unstarted) state.
1006
1007        @param data: List with tuples of the form (link_name, link_value), as
1008                     returned by _split_local_from_foreign_values.
1009
1010        @returns List of the same format as data, but only containing data for
1011                 fields that updates are allowed on.
1012        """
1013        return [pair for pair in data
1014                if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE]
1015
1016
1017    @classmethod
1018    def delete_matching_record(cls, **filter_args):
1019        """Delete records matching the filter.
1020
1021        @param filter_args: Arguments for the django filter
1022                used to locate the record to delete.
1023        """
1024        try:
1025            existing_record = cls.objects.get(**filter_args)
1026        except cls.DoesNotExist:
1027            return
1028        existing_record.delete()
1029
1030
1031    def _deserialize_local(self, data):
1032        """Set local attributes from a list of tuples.
1033
1034        @param data: List of tuples like returned by
1035                     _split_local_from_foreign_values.
1036        """
1037        if not data:
1038            return
1039
1040        for link, value in data:
1041            setattr(self, link, value)
1042        # Overwridden save() methods are prone to errors, so don't execute them.
1043        # This is because:
1044        # - the overwritten methods depend on ACL groups that don't yet exist
1045        #   and don't handle errors
1046        # - the overwritten methods think this object already exists in the db
1047        #   because the id is already set
1048        super(type(self), self).save()
1049
1050
1051    def _deserialize_relations(self, data):
1052        """Set foreign attributes from a list of tuples.
1053
1054        This deserialized the related objects using their own deserialize()
1055        function and then sets the relation.
1056
1057        @param data: List of tuples like returned by
1058                     _split_local_from_foreign_values.
1059        """
1060        for link, value in data:
1061            self._deserialize_relation(link, value)
1062        # See comment in _deserialize_local
1063        super(type(self), self).save()
1064
1065
1066    @classmethod
1067    def get_record(cls, data):
1068        """Retrieve a record with the data in the given input arg.
1069
1070        @param data: A dictionary containing the information to use in a query
1071                for data. If child models have different constraints of
1072                uniqueness they should override this model.
1073
1074        @return: An object with matching data.
1075
1076        @raises DoesNotExist: If a record with the given data doesn't exist.
1077        """
1078        return cls.objects.get(id=data['id'])
1079
1080
1081    @classmethod
1082    def deserialize(cls, data):
1083        """Recursively deserializes and saves an object with it's dependencies.
1084
1085        This takes the result of the serialize method and creates objects
1086        in the database that are just like the original.
1087
1088        If an object of the same type with the same id already exists, it's
1089        local values will be left untouched, unless they are explicitly
1090        whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE.
1091
1092        Deserialize will always recursively propagate to all related objects
1093        present in data though.
1094        I.e. this is necessary to add users to an already existing acl-group.
1095
1096        @param data: Representation of an object and its dependencies, as
1097                     returned by serialize.
1098
1099        @returns: The object represented by data if it didn't exist before,
1100                  otherwise the object that existed before and has the same type
1101                  and id as the one described by data.
1102        """
1103        if data is None:
1104            return None
1105
1106        local, related = cls._split_local_from_foreign_values(data)
1107        try:
1108            instance = cls.get_record(data)
1109            local = cls._filter_update_allowed_fields(local)
1110        except cls.DoesNotExist:
1111            instance = cls()
1112
1113        instance._deserialize_local(local)
1114        instance._deserialize_relations(related)
1115
1116        return instance
1117
1118
1119    def sanity_check_update_from_shard(self, shard, updated_serialized,
1120                                       *args, **kwargs):
1121        """Check if an update sent from a shard is legitimate.
1122
1123        @raises error.UnallowedRecordsSentToMaster if an update is not
1124                legitimate.
1125        """
1126        raise NotImplementedError(
1127            'sanity_check_update_from_shard must be implemented by subclass %s '
1128            'for type %s' % type(self))
1129
1130
1131    @transaction.commit_on_success
1132    def update_from_serialized(self, serialized):
1133        """Updates local fields of an existing object from a serialized form.
1134
1135        This is different than the normal deserialize() in the way that it
1136        does update local values, which deserialize doesn't, but doesn't
1137        recursively propagate to related objects, which deserialize() does.
1138
1139        The use case of this function is to update job records on the master
1140        after the jobs have been executed on a slave, as the master is not
1141        interested in updates for users, labels, specialtasks, etc.
1142
1143        @param serialized: Representation of an object and its dependencies, as
1144                           returned by serialize.
1145
1146        @raises ValueError: if serialized contains related objects, i.e. not
1147                            only local fields.
1148        """
1149        local, related = (
1150            self._split_local_from_foreign_values(serialized))
1151        if related:
1152            raise ValueError('Serialized must not contain foreign '
1153                             'objects: %s' % related)
1154
1155        self._deserialize_local(local)
1156
1157
1158    def custom_deserialize_relation(self, link, data):
1159        """Allows overriding the deserialization behaviour by subclasses."""
1160        raise NotImplementedError(
1161            'custom_deserialize_relation must be implemented by subclass %s '
1162            'for relation %s' % (type(self), link))
1163
1164
1165    def _deserialize_relation(self, link, data):
1166        """Deserializes related objects and sets references on this object.
1167
1168        Relations that point to a list of objects are handled automatically.
1169        For many-to-one or one-to-one relations custom_deserialize_relation
1170        must be overridden by the subclass.
1171
1172        Related objects are deserialized using their deserialize() method.
1173        Thereby they and their dependencies are created if they don't exist
1174        and saved to the database.
1175
1176        @param link: Name of the relation.
1177        @param data: Serialized representation of the related object(s).
1178                     This means a list of dictionaries for to-many relations,
1179                     just a dictionary for to-one relations.
1180        """
1181        field = getattr(self, link)
1182
1183        if field and hasattr(field, 'all'):
1184            self._deserialize_2m_relation(link, data, field.model)
1185        else:
1186            self.custom_deserialize_relation(link, data)
1187
1188
1189    def _deserialize_2m_relation(self, link, data, related_class):
1190        """Deserialize related objects for one to-many relationship.
1191
1192        @param link: Name of the relation.
1193        @param data: Serialized representation of the related objects.
1194                     This is a list with of dictionaries.
1195        @param related_class: A class representing a django model, with which
1196                              this class has a one-to-many relationship.
1197        """
1198        relation_set = getattr(self, link)
1199        if related_class == self.get_attribute_model():
1200            # When deserializing a model together with
1201            # its attributes, clear all the exising attributes to ensure
1202            # db consistency. Note 'update' won't be sufficient, as we also
1203            # want to remove any attributes that no longer exist in |data|.
1204            #
1205            # core_filters is a dictionary of filters, defines how
1206            # RelatedMangager would query for the 1-to-many relationship. E.g.
1207            # Host.objects.get(
1208            #     id=20).hostattribute_set.core_filters = {host_id:20}
1209            # We use it to delete objects related to the current object.
1210            related_class.objects.filter(**relation_set.core_filters).delete()
1211        for serialized in data:
1212            relation_set.add(related_class.deserialize(serialized))
1213
1214
1215    @classmethod
1216    def get_attribute_model(cls):
1217        """Return the attribute model.
1218
1219        Subclass with attribute-like model should override this to
1220        return the attribute model class. This method will be
1221        called by _deserialize_2m_relation to determine whether
1222        to clear the one-to-many relations first on deserialization of object.
1223        """
1224        return None
1225
1226
class ModelWithInvalid(ModelExtensions):
    """
    Model base class that replaces actual deletion with invalidation.

    save() and delete() are overridden: deleting an object only flags it as
    invalid, and saving a new object takes over the id of a previously
    invalidated object with the same name.  Subclasses must have a boolean
    "invalid" field.
    """

    def save(self, *args, **kwargs):
        """
        Save the object, resurrecting an invalidated one if applicable.

        If this object has no id yet, an invalid object with the same
        name_field value is looked up and, if found, handed to
        resurrect_object() before the regular save runs.
        """
        if self.id is None:
            # see if this object was previously added and invalidated
            name_value = getattr(self, self.name_field)
            lookup = {self.name_field: name_value, 'invalid': True}
            try:
                previous = self.__class__.objects.get(**lookup)
                self.resurrect_object(previous)
            except self.DoesNotExist:
                # no existing object
                pass

        super(ModelWithInvalid, self).save(*args, **kwargs)


    def resurrect_object(self, old_object):
        """
        Called when self is about to be saved for the first time and is
        actually "undeleting" a previously deleted object.  Subclasses can
        override this to copy data as desired from the deleted entry (but
        this superclass implementation must normally be called).

        @param old_object: the invalidated object whose id is taken over.
        """
        self.id = old_object.id


    def clean_object(self):
        """
        Hook called when an object is marked invalid.  Subclasses should
        override this to clean up relationships that should no longer exist
        if the object were deleted.  Does nothing here.
        """
        return None


    def delete(self):
        """
        Mark the object invalid instead of deleting it, save it, and then
        call clean_object().  The object must not already be invalid.
        """
        # NOTE(review): this self-assignment looks like a no-op; it is kept
        # as found.
        self.invalid = self.invalid
        assert not self.invalid
        self.invalid = True
        self.save()
        self.clean_object()


    @classmethod
    def get_valid_manager(cls):
        """Return cls.valid_objects as the manager for valid objects."""
        return cls.valid_objects


    class Manipulator(object):
        """
        Force default manipulators to look only at valid objects -
        otherwise they will match against invalid objects when checking
        uniqueness.
        """
        @classmethod
        def _prepare(cls, model):
            """Run the parent _prepare, then use model.valid_objects."""
            super(ModelWithInvalid.Manipulator, cls)._prepare(model)
            cls.manager = model.valid_objects
1292
1293
class ModelWithAttributes(object):
    """
    Mixin class for models that have an attribute model associated with them.
    The attribute model is assumed to have its value field named "value".
    """

    def _get_attribute_model_and_args(self, attribute):
        """
        Subclasses should override this to return a tuple (attribute_model,
        keyword_args), where attribute_model is a model class and keyword_args
        is a dict of args to pass to attribute_model.objects.get() to get an
        instance of the given attribute on this object.

        @raises NotImplementedError: always, in this base implementation.
        """
        raise NotImplementedError


    def set_attribute(self, attribute, value):
        """
        Store value under the given attribute, creating the attribute
        object if it does not exist yet.
        """
        model, lookup = self._get_attribute_model_and_args(attribute)
        # get_or_create returns a pair; only the object itself is needed.
        entry = model.objects.get_or_create(**lookup)[0]
        entry.value = value
        entry.save()


    def delete_attribute(self, attribute):
        """
        Delete the given attribute; do nothing if it does not exist.
        """
        model, lookup = self._get_attribute_model_and_args(attribute)
        try:
            model.objects.get(**lookup).delete()
        except model.DoesNotExist:
            pass


    def set_or_delete_attribute(self, attribute, value):
        """
        Delete the attribute if value is None, otherwise set it to value.
        """
        if value is not None:
            self.set_attribute(attribute, value)
            return
        self.delete_attribute(attribute)
1332
1333
class ModelWithHashManager(dbmodels.Manager):
    """Manager for use with the ModelWithHash abstract model class"""

    def create(self, **kwargs):
        """Always raises: get_or_create() must be used instead."""
        message = ('ModelWithHash manager should use get_or_create() '
                   'instead of create()')
        raise Exception(message)


    def get_or_create(self, **kwargs):
        """
        Compute the hash for kwargs with the model's _compute_hash(), add
        it as 'the_hash', and delegate to the parent get_or_create().
        """
        digest = self.model._compute_hash(**kwargs)
        kwargs['the_hash'] = digest
        parent = super(ModelWithHashManager, self)
        return parent.get_or_create(**kwargs)
1345
1346
class ModelWithHash(dbmodels.Model):
    """Superclass with methods for dealing with a hash column"""

    # Unique hash column; ModelWithHashManager.get_or_create() fills it with
    # the value computed by _compute_hash().
    the_hash = dbmodels.CharField(max_length=40, unique=True)

    objects = ModelWithHashManager()

    class Meta:
        abstract = True


    @classmethod
    def _compute_hash(cls, **kwargs):
        """Compute the hash for the given field values.

        @raises NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError('Subclasses must override _compute_hash()')


    def save(self, force_insert=False, **kwargs):
        """Prevents saving the model in most cases

        We want these models to be immutable, so the generic save() operation
        will not work. These models should be instantiated through their
        model.objects.get_or_create() method instead.

        The exception is that save(force_insert=True) will be allowed, since
        that creates a new row. However, the preferred way to make instances of
        these models is through the get_or_create() method.

        @param force_insert: must be true for the save to be performed.

        @raises Exception: if force_insert is false.
        """
        if force_insert:
            # Allow a forced insert to happen; if it's a duplicate, the
            # unique constraint will catch it later anyways
            super(ModelWithHash, self).save(force_insert=force_insert,
                                            **kwargs)
            return
        raise Exception('ModelWithHash is immutable')
1379