1""" 2Extensions to Django's model logic. 3""" 4 5import django.core.exceptions 6from django.db import connection 7from django.db import connections 8from django.db import models as dbmodels 9from django.db import transaction 10from django.db.models.sql import query 11import django.db.models.sql.where 12from autotest_lib.frontend.afe import rdb_model_extensions 13 14 15class ValidationError(django.core.exceptions.ValidationError): 16 """\ 17 Data validation error in adding or updating an object. The associated 18 value is a dictionary mapping field names to error strings. 19 """ 20 21def _quote_name(name): 22 """Shorthand for connection.ops.quote_name().""" 23 return connection.ops.quote_name(name) 24 25 26class LeasedHostManager(dbmodels.Manager): 27 """Query manager for unleased, unlocked hosts. 28 """ 29 def get_query_set(self): 30 return (super(LeasedHostManager, self).get_query_set().filter( 31 leased=0, locked=0)) 32 33 34class ExtendedManager(dbmodels.Manager): 35 """\ 36 Extended manager supporting subquery filtering. 37 """ 38 39 class CustomQuery(query.Query): 40 def __init__(self, *args, **kwargs): 41 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs) 42 self._custom_joins = [] 43 44 45 def clone(self, klass=None, **kwargs): 46 obj = super(ExtendedManager.CustomQuery, self).clone(klass) 47 obj._custom_joins = list(self._custom_joins) 48 return obj 49 50 51 def combine(self, rhs, connector): 52 super(ExtendedManager.CustomQuery, self).combine(rhs, connector) 53 if hasattr(rhs, '_custom_joins'): 54 self._custom_joins.extend(rhs._custom_joins) 55 56 57 def add_custom_join(self, table, condition, join_type, 58 condition_values=(), alias=None): 59 if alias is None: 60 alias = table 61 join_dict = dict(table=table, 62 condition=condition, 63 condition_values=condition_values, 64 join_type=join_type, 65 alias=alias) 66 self._custom_joins.append(join_dict) 67 68 69 @classmethod 70 def convert_query(self, query_set): 71 """ 72 Convert the query set's "query" attribute to a CustomQuery. 73 """ 74 # Make a copy of the query set 75 query_set = query_set.all() 76 query_set.query = query_set.query.clone( 77 klass=ExtendedManager.CustomQuery, 78 _custom_joins=[]) 79 return query_set 80 81 82 class _WhereClause(object): 83 """Object allowing us to inject arbitrary SQL into Django queries. 84 85 By using this instead of extra(where=...), we can still freely combine 86 queries with & and |. 87 """ 88 def __init__(self, clause, values=()): 89 self._clause = clause 90 self._values = values 91 92 93 def as_sql(self, qn=None, connection=None): 94 return self._clause, self._values 95 96 97 def relabel_aliases(self, change_map): 98 return 99 100 101 def add_join(self, query_set, join_table, join_key, join_condition='', 102 join_condition_values=(), join_from_key=None, alias=None, 103 suffix='', exclude=False, force_left_join=False): 104 """Add a join to query_set. 105 106 Join looks like this: 107 (INNER|LEFT) JOIN <join_table> AS <alias> 108 ON (<this table>.<join_from_key> = <join_table>.<join_key> 109 and <join_condition>) 110 111 @param join_table table to join to 112 @param join_key field referencing back to this model to use for the join 113 @param join_condition extra condition for the ON clause of the join 114 @param join_condition_values values to substitute into join_condition 115 @param join_from_key column on this model to join from. 116 @param alias alias to use for for join 117 @param suffix suffix to add to join_table for the join alias, if no 118 alias is provided 119 @param exclude if true, exclude rows that match this join (will use a 120 LEFT OUTER JOIN and an appropriate WHERE condition) 121 @param force_left_join - if true, a LEFT OUTER JOIN will be used 122 instead of an INNER JOIN regardless of other options 123 """ 124 join_from_table = query_set.model._meta.db_table 125 if join_from_key is None: 126 join_from_key = self.model._meta.pk.name 127 if alias is None: 128 alias = join_table + suffix 129 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) 130 full_join_condition = '%s = %s.%s' % (full_join_key, 131 _quote_name(join_from_table), 132 _quote_name(join_from_key)) 133 if join_condition: 134 full_join_condition += ' AND (' + join_condition + ')' 135 if exclude or force_left_join: 136 join_type = query_set.query.LOUTER 137 else: 138 join_type = query_set.query.INNER 139 140 query_set = self.CustomQuery.convert_query(query_set) 141 query_set.query.add_custom_join(join_table, 142 full_join_condition, 143 join_type, 144 condition_values=join_condition_values, 145 alias=alias) 146 147 if exclude: 148 query_set = query_set.extra(where=[full_join_key + ' IS NULL']) 149 150 return query_set 151 152 153 def _info_for_many_to_one_join(self, field, join_to_query, alias): 154 """ 155 @param field: the ForeignKey field on the related model 156 @param join_to_query: the query over the related model that we're 157 joining to 158 @param alias: alias of joined table 159 """ 160 info = {} 161 rhs_table = join_to_query.model._meta.db_table 162 info['rhs_table'] = rhs_table 163 info['rhs_column'] = field.column 164 info['lhs_column'] = field.rel.get_related_field().column 165 rhs_where = join_to_query.query.where 166 rhs_where.relabel_aliases({rhs_table: alias}) 167 compiler = join_to_query.query.get_compiler(using=join_to_query.db) 168 initial_clause, values = compiler.as_sql() 169 # initial_clause is compiled from `join_to_query`, which is a SELECT 170 # query returns at most one record. For it to be used in WHERE clause, 171 # it must be converted to a boolean value using EXISTS. 172 all_clauses = ('EXISTS (%s)' % initial_clause,) 173 if hasattr(join_to_query.query, 'extra_where'): 174 all_clauses += join_to_query.query.extra_where 175 info['where_clause'] = ( 176 ' AND '.join('(%s)' % clause for clause in all_clauses)) 177 info['values'] = values 178 return info 179 180 181 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, 182 m2m_is_on_this_model): 183 """ 184 @param m2m_field: a Django field representing the M2M relationship. 185 It uses a pivot table with the following structure: 186 this model table <---> M2M pivot table <---> joined model table 187 @param join_to_query: the query over the related model that we're 188 joining to. 189 @param alias: alias of joined table 190 """ 191 if m2m_is_on_this_model: 192 # referenced field on this model 193 lhs_id_field = self.model._meta.pk 194 # foreign key on the pivot table referencing lhs_id_field 195 m2m_lhs_column = m2m_field.m2m_column_name() 196 # foreign key on the pivot table referencing rhd_id_field 197 m2m_rhs_column = m2m_field.m2m_reverse_name() 198 # referenced field on related model 199 rhs_id_field = m2m_field.rel.get_related_field() 200 else: 201 lhs_id_field = m2m_field.rel.get_related_field() 202 m2m_lhs_column = m2m_field.m2m_reverse_name() 203 m2m_rhs_column = m2m_field.m2m_column_name() 204 rhs_id_field = join_to_query.model._meta.pk 205 206 info = {} 207 info['rhs_table'] = m2m_field.m2m_db_table() 208 info['rhs_column'] = m2m_lhs_column 209 info['lhs_column'] = lhs_id_field.column 210 211 # select the ID of related models relevant to this join. we can only do 212 # a single join, so we need to gather this information up front and 213 # include it in the join condition. 214 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) 215 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' 216 'match a single related object.') 217 rhs_id = rhs_ids[0] 218 219 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), 220 _quote_name(m2m_rhs_column), 221 rhs_id) 222 info['values'] = () 223 return info 224 225 226 def join_custom_field(self, query_set, join_to_query, alias, 227 left_join=True): 228 """Join to a related model to create a custom field in the given query. 229 230 This method is used to construct a custom field on the given query based 231 on a many-valued relationsip. join_to_query should be a simple query 232 (no joins) on the related model which returns at most one related row 233 per instance of this model. 234 235 For many-to-one relationships, the joined table contains the matching 236 row from the related model it one is related, NULL otherwise. 237 238 For many-to-many relationships, the joined table contains the matching 239 row if it's related, NULL otherwise. 240 """ 241 relationship_type, field = self.determine_relationship( 242 join_to_query.model) 243 244 if relationship_type == self.MANY_TO_ONE: 245 info = self._info_for_many_to_one_join(field, join_to_query, alias) 246 elif relationship_type == self.M2M_ON_RELATED_MODEL: 247 info = self._info_for_many_to_many_join( 248 m2m_field=field, join_to_query=join_to_query, alias=alias, 249 m2m_is_on_this_model=False) 250 elif relationship_type ==self.M2M_ON_THIS_MODEL: 251 info = self._info_for_many_to_many_join( 252 m2m_field=field, join_to_query=join_to_query, alias=alias, 253 m2m_is_on_this_model=True) 254 255 return self.add_join(query_set, info['rhs_table'], info['rhs_column'], 256 join_from_key=info['lhs_column'], 257 join_condition=info['where_clause'], 258 join_condition_values=info['values'], 259 alias=alias, 260 force_left_join=left_join) 261 262 263 def add_where(self, query_set, where, values=()): 264 query_set = query_set.all() 265 query_set.query.where.add(self._WhereClause(where, values), 266 django.db.models.sql.where.AND) 267 return query_set 268 269 270 def _get_quoted_field(self, table, field): 271 return _quote_name(table) + '.' + _quote_name(field) 272 273 274 def get_key_on_this_table(self, key_field=None): 275 if key_field is None: 276 # default to primary key 277 key_field = self.model._meta.pk.column 278 return self._get_quoted_field(self.model._meta.db_table, key_field) 279 280 281 def escape_user_sql(self, sql): 282 return sql.replace('%', '%%') 283 284 285 def _custom_select_query(self, query_set, selects): 286 """Execute a custom select query. 287 288 @param query_set: query set as returned by query_objects. 289 @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id. 290 291 @returns: Result of the query as returned by cursor.fetchall(). 292 """ 293 compiler = query_set.query.get_compiler(using=query_set.db) 294 sql, params = compiler.as_sql() 295 from_ = sql[sql.find(' FROM'):] 296 297 if query_set.query.distinct: 298 distinct = 'DISTINCT ' 299 else: 300 distinct = '' 301 302 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_) 303 # Chose the connection that's responsible for this type of object 304 cursor = connections[query_set.db].cursor() 305 cursor.execute(sql_query, params) 306 return cursor.fetchall() 307 308 309 def _is_relation_to(self, field, model_class): 310 return field.rel and field.rel.to is model_class 311 312 313 MANY_TO_ONE = object() 314 M2M_ON_RELATED_MODEL = object() 315 M2M_ON_THIS_MODEL = object() 316 317 def determine_relationship(self, related_model): 318 """ 319 Determine the relationship between this model and related_model. 320 321 related_model must have some sort of many-valued relationship to this 322 manager's model. 323 @returns (relationship_type, field), where relationship_type is one of 324 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field 325 is the Django field object for the relationship. 326 """ 327 # look for a foreign key field on related_model relating to this model 328 for field in related_model._meta.fields: 329 if self._is_relation_to(field, self.model): 330 return self.MANY_TO_ONE, field 331 332 # look for an M2M field on related_model relating to this model 333 for field in related_model._meta.many_to_many: 334 if self._is_relation_to(field, self.model): 335 return self.M2M_ON_RELATED_MODEL, field 336 337 # maybe this model has the many-to-many field 338 for field in self.model._meta.many_to_many: 339 if self._is_relation_to(field, related_model): 340 return self.M2M_ON_THIS_MODEL, field 341 342 raise ValueError('%s has no relation to %s' % 343 (related_model, self.model)) 344 345 346 def _get_pivot_iterator(self, base_objects_by_id, related_model): 347 """ 348 Determine the relationship between this model and related_model, and 349 return a pivot iterator. 350 @param base_objects_by_id: dict of instances of this model indexed by 351 their IDs 352 @returns a pivot iterator, which yields a tuple (base_object, 353 related_object) for each relationship between a base object and a 354 related object. all base_object instances come from base_objects_by_id. 355 Note -- this depends on Django model internals. 356 """ 357 relationship_type, field = self.determine_relationship(related_model) 358 if relationship_type == self.MANY_TO_ONE: 359 return self._many_to_one_pivot(base_objects_by_id, 360 related_model, field) 361 elif relationship_type == self.M2M_ON_RELATED_MODEL: 362 return self._many_to_many_pivot( 363 base_objects_by_id, related_model, field.m2m_db_table(), 364 field.m2m_reverse_name(), field.m2m_column_name()) 365 else: 366 assert relationship_type == self.M2M_ON_THIS_MODEL 367 return self._many_to_many_pivot( 368 base_objects_by_id, related_model, field.m2m_db_table(), 369 field.m2m_column_name(), field.m2m_reverse_name()) 370 371 372 def _many_to_one_pivot(self, base_objects_by_id, related_model, 373 foreign_key_field): 374 """ 375 @returns a pivot iterator - see _get_pivot_iterator() 376 """ 377 filter_data = {foreign_key_field.name + '__pk__in': 378 base_objects_by_id.keys()} 379 for related_object in related_model.objects.filter(**filter_data): 380 # lookup base object in the dict, rather than grabbing it from the 381 # related object. we need to return instances from the dict, not 382 # fresh instances of the same models (and grabbing model instances 383 # from the related models incurs a DB query each time). 384 base_object_id = getattr(related_object, foreign_key_field.attname) 385 base_object = base_objects_by_id[base_object_id] 386 yield base_object, related_object 387 388 389 def _query_pivot_table(self, base_objects_by_id, pivot_table, 390 pivot_from_field, pivot_to_field, related_model): 391 """ 392 @param id_list list of IDs of self.model objects to include 393 @param pivot_table the name of the pivot table 394 @param pivot_from_field a field name on pivot_table referencing 395 self.model 396 @param pivot_to_field a field name on pivot_table referencing the 397 related model. 398 @param related_model the related model 399 400 @returns pivot list of IDs (base_id, related_id) 401 """ 402 query = """ 403 SELECT %(from_field)s, %(to_field)s 404 FROM %(table)s 405 WHERE %(from_field)s IN (%(id_list)s) 406 """ % dict(from_field=pivot_from_field, 407 to_field=pivot_to_field, 408 table=pivot_table, 409 id_list=','.join(str(id_) for id_ 410 in base_objects_by_id.iterkeys())) 411 412 # Chose the connection that's responsible for this type of object 413 # The databases for related_model and the current model will always 414 # be the same, related_model is just easier to obtain here because 415 # self is only a ExtendedManager, not the object. 416 cursor = connections[related_model.objects.db].cursor() 417 cursor.execute(query) 418 return cursor.fetchall() 419 420 421 def _many_to_many_pivot(self, base_objects_by_id, related_model, 422 pivot_table, pivot_from_field, pivot_to_field): 423 """ 424 @param pivot_table: see _query_pivot_table 425 @param pivot_from_field: see _query_pivot_table 426 @param pivot_to_field: see _query_pivot_table 427 @returns a pivot iterator - see _get_pivot_iterator() 428 """ 429 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table, 430 pivot_from_field, pivot_to_field, 431 related_model) 432 433 all_related_ids = list(set(related_id for base_id, related_id 434 in id_pivot)) 435 related_objects_by_id = related_model.objects.in_bulk(all_related_ids) 436 437 for base_id, related_id in id_pivot: 438 yield base_objects_by_id[base_id], related_objects_by_id[related_id] 439 440 441 def populate_relationships(self, base_objects, related_model, 442 related_list_name): 443 """ 444 For each instance of this model in base_objects, add a field named 445 related_list_name listing all the related objects of type related_model. 446 related_model must be in a many-to-one or many-to-many relationship with 447 this model. 448 @param base_objects - list of instances of this model 449 @param related_model - model class related to this model 450 @param related_list_name - attribute name in which to store the related 451 object list. 452 """ 453 if not base_objects: 454 # if we don't bail early, we'll get a SQL error later 455 return 456 457 base_objects_by_id = dict((base_object._get_pk_val(), base_object) 458 for base_object in base_objects) 459 pivot_iterator = self._get_pivot_iterator(base_objects_by_id, 460 related_model) 461 462 for base_object in base_objects: 463 setattr(base_object, related_list_name, []) 464 465 for base_object, related_object in pivot_iterator: 466 getattr(base_object, related_list_name).append(related_object) 467 468 469class ModelWithInvalidQuerySet(dbmodels.query.QuerySet): 470 """ 471 QuerySet that handles delete() properly for models with an "invalid" bit 472 """ 473 def delete(self): 474 for model in self: 475 model.delete() 476 477 478class ModelWithInvalidManager(ExtendedManager): 479 """ 480 Manager for objects with an "invalid" bit 481 """ 482 def get_query_set(self): 483 return ModelWithInvalidQuerySet(self.model) 484 485 486class ValidObjectsManager(ModelWithInvalidManager): 487 """ 488 Manager returning only objects with invalid=False. 489 """ 490 def get_query_set(self): 491 queryset = super(ValidObjectsManager, self).get_query_set() 492 return queryset.filter(invalid=False) 493 494 495class ModelExtensions(rdb_model_extensions.ModelValidators): 496 """\ 497 Mixin with convenience functions for models, built on top of 498 the model validators in rdb_model_extensions. 499 """ 500 # TODO: at least some of these functions really belong in a custom 501 # Manager class 502 503 504 SERIALIZATION_LINKS_TO_FOLLOW = set() 505 """ 506 To be able to send jobs and hosts to shards, it's necessary to find their 507 dependencies. 508 The most generic approach for this would be to traverse all relationships 509 to other objects recursively. This would list all objects that are related 510 in any way. 511 But this approach finds too many objects: If a host should be transferred, 512 all it's relationships would be traversed. This would find an acl group. 513 If then the acl group's relationships are traversed, the relationship 514 would be followed backwards and many other hosts would be found. 515 516 This mapping tells that algorithm which relations to follow explicitly. 517 """ 518 519 520 SERIALIZATION_LINKS_TO_KEEP = set() 521 """This set stores foreign keys which we don't want to follow, but 522 still want to include in the serialized dictionary. For 523 example, we follow the relationship `Host.hostattribute_set`, 524 but we do not want to follow `HostAttributes.host_id` back to 525 to Host, which would otherwise lead to a circle. However, we still 526 like to serialize HostAttribute.`host_id`.""" 527 528 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set() 529 """ 530 On deserializion, if the object to persist already exists, local fields 531 will only be updated, if their name is in this set. 532 """ 533 534 535 @classmethod 536 def convert_human_readable_values(cls, data, to_human_readable=False): 537 """\ 538 Performs conversions on user-supplied field data, to make it 539 easier for users to pass human-readable data. 540 541 For all fields that have choice sets, convert their values 542 from human-readable strings to enum values, if necessary. This 543 allows users to pass strings instead of the corresponding 544 integer values. 545 546 For all foreign key fields, call smart_get with the supplied 547 data. This allows the user to pass either an ID value or 548 the name of the object as a string. 549 550 If to_human_readable=True, perform the inverse - i.e. convert 551 numeric values to human readable values. 552 553 This method modifies data in-place. 554 """ 555 field_dict = cls.get_field_dict() 556 for field_name in data: 557 if field_name not in field_dict or data[field_name] is None: 558 continue 559 field_obj = field_dict[field_name] 560 # convert enum values 561 if field_obj.choices: 562 for choice_data in field_obj.choices: 563 # choice_data is (value, name) 564 if to_human_readable: 565 from_val, to_val = choice_data 566 else: 567 to_val, from_val = choice_data 568 if from_val == data[field_name]: 569 data[field_name] = to_val 570 break 571 # convert foreign key values 572 elif field_obj.rel: 573 dest_obj = field_obj.rel.to.smart_get(data[field_name], 574 valid_only=False) 575 if to_human_readable: 576 # parameterized_jobs do not have a name_field 577 if (field_name != 'parameterized_job' and 578 dest_obj.name_field is not None): 579 data[field_name] = getattr(dest_obj, 580 dest_obj.name_field) 581 else: 582 data[field_name] = dest_obj 583 584 585 586 587 def _validate_unique(self): 588 """\ 589 Validate that unique fields are unique. Django manipulators do 590 this too, but they're a huge pain to use manually. Trust me. 591 """ 592 errors = {} 593 cls = type(self) 594 field_dict = self.get_field_dict() 595 manager = cls.get_valid_manager() 596 for field_name, field_obj in field_dict.iteritems(): 597 if not field_obj.unique: 598 continue 599 600 value = getattr(self, field_name) 601 if value is None and field_obj.auto_created: 602 # don't bother checking autoincrement fields about to be 603 # generated 604 continue 605 606 existing_objs = manager.filter(**{field_name : value}) 607 num_existing = existing_objs.count() 608 609 if num_existing == 0: 610 continue 611 if num_existing == 1 and existing_objs[0].id == self.id: 612 continue 613 errors[field_name] = ( 614 'This value must be unique (%s)' % (value)) 615 return errors 616 617 618 def _validate(self): 619 """ 620 First coerces all fields on this instance to their proper Python types. 621 Then runs validation on every field. Returns a dictionary of 622 field_name -> error_list. 623 624 Based on validate() from django.db.models.Model in Django 0.96, which 625 was removed in Django 1.0. It should reappear in a later version. See: 626 http://code.djangoproject.com/ticket/6845 627 """ 628 error_dict = {} 629 for f in self._meta.fields: 630 try: 631 python_value = f.to_python( 632 getattr(self, f.attname, f.get_default())) 633 except django.core.exceptions.ValidationError, e: 634 error_dict[f.name] = str(e) 635 continue 636 637 if not f.blank and not python_value: 638 error_dict[f.name] = 'This field is required.' 639 continue 640 641 setattr(self, f.attname, python_value) 642 643 return error_dict 644 645 646 def do_validate(self): 647 errors = self._validate() 648 unique_errors = self._validate_unique() 649 for field_name, error in unique_errors.iteritems(): 650 errors.setdefault(field_name, error) 651 if errors: 652 raise ValidationError(errors) 653 654 655 # actually (externally) useful methods follow 656 657 @classmethod 658 def add_object(cls, data={}, **kwargs): 659 """\ 660 Returns a new object created with the given data (a dictionary 661 mapping field names to values). Merges any extra keyword args 662 into data. 663 """ 664 data = dict(data) 665 data.update(kwargs) 666 data = cls.prepare_data_args(data) 667 cls.convert_human_readable_values(data) 668 data = cls.provide_default_values(data) 669 670 obj = cls(**data) 671 obj.do_validate() 672 obj.save() 673 return obj 674 675 676 def update_object(self, data={}, **kwargs): 677 """\ 678 Updates the object with the given data (a dictionary mapping 679 field names to values). Merges any extra keyword args into 680 data. 681 """ 682 data = dict(data) 683 data.update(kwargs) 684 data = self.prepare_data_args(data) 685 self.convert_human_readable_values(data) 686 for field_name, value in data.iteritems(): 687 setattr(self, field_name, value) 688 self.do_validate() 689 self.save() 690 691 692 # see query_objects() 693 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by', 694 'extra_args', 'extra_where', 'no_distinct') 695 696 697 @classmethod 698 def _extract_special_params(cls, filter_data): 699 """ 700 @returns a tuple of dicts (special_params, regular_filters), where 701 special_params contains the parameters we handle specially and 702 regular_filters is the remaining data to be handled by Django. 703 """ 704 regular_filters = dict(filter_data) 705 special_params = {} 706 for key in cls._SPECIAL_FILTER_KEYS: 707 if key in regular_filters: 708 special_params[key] = regular_filters.pop(key) 709 return special_params, regular_filters 710 711 712 @classmethod 713 def apply_presentation(cls, query, filter_data): 714 """ 715 Apply presentation parameters -- sorting and paging -- to the given 716 query. 717 @returns new query with presentation applied 718 """ 719 special_params, _ = cls._extract_special_params(filter_data) 720 sort_by = special_params.get('sort_by', None) 721 if sort_by: 722 assert isinstance(sort_by, list) or isinstance(sort_by, tuple) 723 query = query.extra(order_by=sort_by) 724 725 query_start = special_params.get('query_start', None) 726 query_limit = special_params.get('query_limit', None) 727 if query_start is not None: 728 if query_limit is None: 729 raise ValueError('Cannot pass query_start without query_limit') 730 # query_limit is passed as a page size 731 query_limit += query_start 732 return query[query_start:query_limit] 733 734 735 @classmethod 736 def query_objects(cls, filter_data, valid_only=True, initial_query=None, 737 apply_presentation=True): 738 """\ 739 Returns a QuerySet object for querying the given model_class 740 with the given filter_data. Optional special arguments in 741 filter_data include: 742 -query_start: index of first return to return 743 -query_limit: maximum number of results to return 744 -sort_by: list of fields to sort on. prefixing a '-' onto a 745 field name changes the sort to descending order. 746 -extra_args: keyword args to pass to query.extra() (see Django 747 DB layer documentation) 748 -extra_where: extra WHERE clause to append 749 -no_distinct: if True, a DISTINCT will not be added to the SELECT 750 """ 751 special_params, regular_filters = cls._extract_special_params( 752 filter_data) 753 754 if initial_query is None: 755 if valid_only: 756 initial_query = cls.get_valid_manager() 757 else: 758 initial_query = cls.objects 759 760 query = initial_query.filter(**regular_filters) 761 762 use_distinct = not special_params.get('no_distinct', False) 763 if use_distinct: 764 query = query.distinct() 765 766 extra_args = special_params.get('extra_args', {}) 767 extra_where = special_params.get('extra_where', None) 768 if extra_where: 769 # escape %'s 770 extra_where = cls.objects.escape_user_sql(extra_where) 771 extra_args.setdefault('where', []).append(extra_where) 772 if extra_args: 773 query = query.extra(**extra_args) 774 # TODO: Use readonly connection for these queries. 775 # This has been disabled, because it's not used anyway, as the 776 # configured readonly user is the same as the real user anyway. 777 778 if apply_presentation: 779 query = cls.apply_presentation(query, filter_data) 780 781 return query 782 783 784 @classmethod 785 def query_count(cls, filter_data, initial_query=None): 786 """\ 787 Like query_objects, but retreive only the count of results. 788 """ 789 filter_data.pop('query_start', None) 790 filter_data.pop('query_limit', None) 791 query = cls.query_objects(filter_data, initial_query=initial_query) 792 return query.count() 793 794 795 @classmethod 796 def clean_object_dicts(cls, field_dicts): 797 """\ 798 Take a list of dicts corresponding to object (as returned by 799 query.values()) and clean the data to be more suitable for 800 returning to the user. 801 """ 802 for field_dict in field_dicts: 803 cls.clean_foreign_keys(field_dict) 804 cls._convert_booleans(field_dict) 805 cls.convert_human_readable_values(field_dict, 806 to_human_readable=True) 807 808 809 @classmethod 810 def list_objects(cls, filter_data, initial_query=None): 811 """\ 812 Like query_objects, but return a list of dictionaries. 813 """ 814 query = cls.query_objects(filter_data, initial_query=initial_query) 815 extra_fields = query.query.extra_select.keys() 816 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields) 817 for model_object in query] 818 return field_dicts 819 820 821 @classmethod 822 def smart_get(cls, id_or_name, valid_only=True): 823 """\ 824 smart_get(integer) -> get object by ID 825 smart_get(string) -> get object by name_field 826 """ 827 if valid_only: 828 manager = cls.get_valid_manager() 829 else: 830 manager = cls.objects 831 832 if isinstance(id_or_name, (int, long)): 833 return manager.get(pk=id_or_name) 834 if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'): 835 return manager.get(**{cls.name_field : id_or_name}) 836 raise ValueError( 837 'Invalid positional argument: %s (%s)' % (id_or_name, 838 type(id_or_name))) 839 840 841 @classmethod 842 def smart_get_bulk(cls, id_or_name_list): 843 invalid_inputs = [] 844 result_objects = [] 845 for id_or_name in id_or_name_list: 846 try: 847 result_objects.append(cls.smart_get(id_or_name)) 848 except cls.DoesNotExist: 849 invalid_inputs.append(id_or_name) 850 if invalid_inputs: 851 raise cls.DoesNotExist('The following %ss do not exist: %s' 852 % (cls.__name__.lower(), 853 ', '.join(invalid_inputs))) 854 return result_objects 855 856 857 def get_object_dict(self, extra_fields=None): 858 """\ 859 Return a dictionary mapping fields to this object's values. @param 860 extra_fields: list of extra attribute names to include, in addition to 861 the fields defined on this object. 862 """ 863 fields = self.get_field_dict().keys() 864 if extra_fields: 865 fields += extra_fields 866 object_dict = dict((field_name, getattr(self, field_name)) 867 for field_name in fields) 868 self.clean_object_dicts([object_dict]) 869 self._postprocess_object_dict(object_dict) 870 return object_dict 871 872 873 def _postprocess_object_dict(self, object_dict): 874 """For subclasses to override.""" 875 pass 876 877 878 @classmethod 879 def get_valid_manager(cls): 880 return cls.objects 881 882 883 def _record_attributes(self, attributes): 884 """ 885 See on_attribute_changed. 886 """ 887 assert not isinstance(attributes, basestring) 888 self._recorded_attributes = dict((attribute, getattr(self, attribute)) 889 for attribute in attributes) 890 891 892 def _check_for_updated_attributes(self): 893 """ 894 See on_attribute_changed. 895 """ 896 for attribute, original_value in self._recorded_attributes.iteritems(): 897 new_value = getattr(self, attribute) 898 if original_value != new_value: 899 self.on_attribute_changed(attribute, original_value) 900 self._record_attributes(self._recorded_attributes.keys()) 901 902 903 def on_attribute_changed(self, attribute, old_value): 904 """ 905 Called whenever an attribute is updated. To be overridden. 906 907 To use this method, you must: 908 * call _record_attributes() from __init__() (after making the super 909 call) with a list of attributes for which you want to be notified upon 910 change. 911 * call _check_for_updated_attributes() from save(). 912 """ 913 pass 914 915 916 def serialize(self, include_dependencies=True): 917 """Serializes the object with dependencies. 918 919 The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies 920 this function will serialize with the object. 921 922 @param include_dependencies: Whether or not to follow relations to 923 objects this object depends on. 924 This parameter is used when uploading 925 jobs from a shard to the master, as the 926 master already has all the dependent 927 objects. 928 929 @returns: Dictionary representation of the object. 930 """ 931 serialized = {} 932 for field in self._meta.concrete_model._meta.local_fields: 933 if field.rel is None: 934 serialized[field.name] = field._get_val_from_obj(self) 935 elif field.name in self.SERIALIZATION_LINKS_TO_KEEP: 936 # attname will contain "_id" suffix for foreign keys, 937 # e.g. HostAttribute.host will be serialized as 'host_id'. 938 # Use it for easy deserialization. 939 serialized[field.attname] = field._get_val_from_obj(self) 940 941 if include_dependencies: 942 for link in self.SERIALIZATION_LINKS_TO_FOLLOW: 943 serialized[link] = self._serialize_relation(link) 944 945 return serialized 946 947 948 def _serialize_relation(self, link): 949 """Serializes dependent objects given the name of the relation. 950 951 @param link: Name of the relation to take objects from. 952 953 @returns For To-Many relationships a list of the serialized related 954 objects, for To-One relationships the serialized related object. 955 """ 956 try: 957 attr = getattr(self, link) 958 except AttributeError: 959 # One-To-One relationships that point to None may raise this 960 return None 961 962 if attr is None: 963 return None 964 if hasattr(attr, 'all'): 965 return [obj.serialize() for obj in attr.all()] 966 return attr.serialize() 967 968 969 @classmethod 970 def _split_local_from_foreign_values(cls, data): 971 """This splits local from foreign values in a serialized object. 972 973 @param data: The serialized object. 974 975 @returns A tuple of two lists, both containing tuples in the form 976 (link_name, link_value). The first list contains all links 977 for local fields, the second one contains those for foreign 978 fields/objects. 979 """ 980 links_to_local_values, links_to_related_values = [], [] 981 for link, value in data.iteritems(): 982 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW: 983 # It's a foreign key 984 links_to_related_values.append((link, value)) 985 else: 986 # It's a local attribute or a foreign key 987 # we don't want to follow. 988 links_to_local_values.append((link, value)) 989 return links_to_local_values, links_to_related_values 990 991 992 @classmethod 993 def _filter_update_allowed_fields(cls, data): 994 """Filters data and returns only files that updates are allowed on. 995 996 This is i.e. needed for syncing aborted bits from the master to shards. 997 998 Local links are only allowed to be updated, if they are in 999 SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1000 Overwriting existing values is allowed in order to be able to sync i.e. 1001 the aborted bit from the master to a shard. 1002 1003 The whitelisting mechanism is in place to prevent overwriting local 1004 status: If all fields were overwritten, jobs would be completely be 1005 set back to their original (unstarted) state. 1006 1007 @param data: List with tuples of the form (link_name, link_value), as 1008 returned by _split_local_from_foreign_values. 1009 1010 @returns List of the same format as data, but only containing data for 1011 fields that updates are allowed on. 1012 """ 1013 return [pair for pair in data 1014 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE] 1015 1016 1017 @classmethod 1018 def delete_matching_record(cls, **filter_args): 1019 """Delete records matching the filter. 1020 1021 @param filter_args: Arguments for the django filter 1022 used to locate the record to delete. 1023 """ 1024 try: 1025 existing_record = cls.objects.get(**filter_args) 1026 except cls.DoesNotExist: 1027 return 1028 existing_record.delete() 1029 1030 1031 def _deserialize_local(self, data): 1032 """Set local attributes from a list of tuples. 1033 1034 @param data: List of tuples like returned by 1035 _split_local_from_foreign_values. 1036 """ 1037 if not data: 1038 return 1039 1040 for link, value in data: 1041 setattr(self, link, value) 1042 # Overwridden save() methods are prone to errors, so don't execute them. 1043 # This is because: 1044 # - the overwritten methods depend on ACL groups that don't yet exist 1045 # and don't handle errors 1046 # - the overwritten methods think this object already exists in the db 1047 # because the id is already set 1048 super(type(self), self).save() 1049 1050 1051 def _deserialize_relations(self, data): 1052 """Set foreign attributes from a list of tuples. 1053 1054 This deserialized the related objects using their own deserialize() 1055 function and then sets the relation. 1056 1057 @param data: List of tuples like returned by 1058 _split_local_from_foreign_values. 1059 """ 1060 for link, value in data: 1061 self._deserialize_relation(link, value) 1062 # See comment in _deserialize_local 1063 super(type(self), self).save() 1064 1065 1066 @classmethod 1067 def get_record(cls, data): 1068 """Retrieve a record with the data in the given input arg. 1069 1070 @param data: A dictionary containing the information to use in a query 1071 for data. If child models have different constraints of 1072 uniqueness they should override this model. 1073 1074 @return: An object with matching data. 1075 1076 @raises DoesNotExist: If a record with the given data doesn't exist. 1077 """ 1078 return cls.objects.get(id=data['id']) 1079 1080 1081 @classmethod 1082 def deserialize(cls, data): 1083 """Recursively deserializes and saves an object with it's dependencies. 1084 1085 This takes the result of the serialize method and creates objects 1086 in the database that are just like the original. 1087 1088 If an object of the same type with the same id already exists, it's 1089 local values will be left untouched, unless they are explicitly 1090 whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1091 1092 Deserialize will always recursively propagate to all related objects 1093 present in data though. 1094 I.e. this is necessary to add users to an already existing acl-group. 1095 1096 @param data: Representation of an object and its dependencies, as 1097 returned by serialize. 1098 1099 @returns: The object represented by data if it didn't exist before, 1100 otherwise the object that existed before and has the same type 1101 and id as the one described by data. 1102 """ 1103 if data is None: 1104 return None 1105 1106 local, related = cls._split_local_from_foreign_values(data) 1107 try: 1108 instance = cls.get_record(data) 1109 local = cls._filter_update_allowed_fields(local) 1110 except cls.DoesNotExist: 1111 instance = cls() 1112 1113 instance._deserialize_local(local) 1114 instance._deserialize_relations(related) 1115 1116 return instance 1117 1118 1119 def sanity_check_update_from_shard(self, shard, updated_serialized, 1120 *args, **kwargs): 1121 """Check if an update sent from a shard is legitimate. 1122 1123 @raises error.UnallowedRecordsSentToMaster if an update is not 1124 legitimate. 1125 """ 1126 raise NotImplementedError( 1127 'sanity_check_update_from_shard must be implemented by subclass %s ' 1128 'for type %s' % type(self)) 1129 1130 1131 @transaction.commit_on_success 1132 def update_from_serialized(self, serialized): 1133 """Updates local fields of an existing object from a serialized form. 1134 1135 This is different than the normal deserialize() in the way that it 1136 does update local values, which deserialize doesn't, but doesn't 1137 recursively propagate to related objects, which deserialize() does. 1138 1139 The use case of this function is to update job records on the master 1140 after the jobs have been executed on a slave, as the master is not 1141 interested in updates for users, labels, specialtasks, etc. 1142 1143 @param serialized: Representation of an object and its dependencies, as 1144 returned by serialize. 1145 1146 @raises ValueError: if serialized contains related objects, i.e. not 1147 only local fields. 1148 """ 1149 local, related = ( 1150 self._split_local_from_foreign_values(serialized)) 1151 if related: 1152 raise ValueError('Serialized must not contain foreign ' 1153 'objects: %s' % related) 1154 1155 self._deserialize_local(local) 1156 1157 1158 def custom_deserialize_relation(self, link, data): 1159 """Allows overriding the deserialization behaviour by subclasses.""" 1160 raise NotImplementedError( 1161 'custom_deserialize_relation must be implemented by subclass %s ' 1162 'for relation %s' % (type(self), link)) 1163 1164 1165 def _deserialize_relation(self, link, data): 1166 """Deserializes related objects and sets references on this object. 1167 1168 Relations that point to a list of objects are handled automatically. 1169 For many-to-one or one-to-one relations custom_deserialize_relation 1170 must be overridden by the subclass. 1171 1172 Related objects are deserialized using their deserialize() method. 1173 Thereby they and their dependencies are created if they don't exist 1174 and saved to the database. 1175 1176 @param link: Name of the relation. 1177 @param data: Serialized representation of the related object(s). 1178 This means a list of dictionaries for to-many relations, 1179 just a dictionary for to-one relations. 1180 """ 1181 field = getattr(self, link) 1182 1183 if field and hasattr(field, 'all'): 1184 self._deserialize_2m_relation(link, data, field.model) 1185 else: 1186 self.custom_deserialize_relation(link, data) 1187 1188 1189 def _deserialize_2m_relation(self, link, data, related_class): 1190 """Deserialize related objects for one to-many relationship. 1191 1192 @param link: Name of the relation. 1193 @param data: Serialized representation of the related objects. 1194 This is a list with of dictionaries. 1195 @param related_class: A class representing a django model, with which 1196 this class has a one-to-many relationship. 1197 """ 1198 relation_set = getattr(self, link) 1199 if related_class == self.get_attribute_model(): 1200 # When deserializing a model together with 1201 # its attributes, clear all the exising attributes to ensure 1202 # db consistency. Note 'update' won't be sufficient, as we also 1203 # want to remove any attributes that no longer exist in |data|. 1204 # 1205 # core_filters is a dictionary of filters, defines how 1206 # RelatedMangager would query for the 1-to-many relationship. E.g. 1207 # Host.objects.get( 1208 # id=20).hostattribute_set.core_filters = {host_id:20} 1209 # We use it to delete objects related to the current object. 1210 related_class.objects.filter(**relation_set.core_filters).delete() 1211 for serialized in data: 1212 relation_set.add(related_class.deserialize(serialized)) 1213 1214 1215 @classmethod 1216 def get_attribute_model(cls): 1217 """Return the attribute model. 1218 1219 Subclass with attribute-like model should override this to 1220 return the attribute model class. This method will be 1221 called by _deserialize_2m_relation to determine whether 1222 to clear the one-to-many relations first on deserialization of object. 1223 """ 1224 return None 1225 1226 1227class ModelWithInvalid(ModelExtensions): 1228 """ 1229 Overrides model methods save() and delete() to support invalidation in 1230 place of actual deletion. Subclasses must have a boolean "invalid" 1231 field. 1232 """ 1233 1234 def save(self, *args, **kwargs): 1235 first_time = (self.id is None) 1236 if first_time: 1237 # see if this object was previously added and invalidated 1238 my_name = getattr(self, self.name_field) 1239 filters = {self.name_field : my_name, 'invalid' : True} 1240 try: 1241 old_object = self.__class__.objects.get(**filters) 1242 self.resurrect_object(old_object) 1243 except self.DoesNotExist: 1244 # no existing object 1245 pass 1246 1247 super(ModelWithInvalid, self).save(*args, **kwargs) 1248 1249 1250 def resurrect_object(self, old_object): 1251 """ 1252 Called when self is about to be saved for the first time and is actually 1253 "undeleting" a previously deleted object. Can be overridden by 1254 subclasses to copy data as desired from the deleted entry (but this 1255 superclass implementation must normally be called). 1256 """ 1257 self.id = old_object.id 1258 1259 1260 def clean_object(self): 1261 """ 1262 This method is called when an object is marked invalid. 1263 Subclasses should override this to clean up relationships that 1264 should no longer exist if the object were deleted. 1265 """ 1266 pass 1267 1268 1269 def delete(self): 1270 self.invalid = self.invalid 1271 assert not self.invalid 1272 self.invalid = True 1273 self.save() 1274 self.clean_object() 1275 1276 1277 @classmethod 1278 def get_valid_manager(cls): 1279 return cls.valid_objects 1280 1281 1282 class Manipulator(object): 1283 """ 1284 Force default manipulators to look only at valid objects - 1285 otherwise they will match against invalid objects when checking 1286 uniqueness. 1287 """ 1288 @classmethod 1289 def _prepare(cls, model): 1290 super(ModelWithInvalid.Manipulator, cls)._prepare(model) 1291 cls.manager = model.valid_objects 1292 1293 1294class ModelWithAttributes(object): 1295 """ 1296 Mixin class for models that have an attribute model associated with them. 1297 The attribute model is assumed to have its value field named "value". 1298 """ 1299 1300 def _get_attribute_model_and_args(self, attribute): 1301 """ 1302 Subclasses should override this to return a tuple (attribute_model, 1303 keyword_args), where attribute_model is a model class and keyword_args 1304 is a dict of args to pass to attribute_model.objects.get() to get an 1305 instance of the given attribute on this object. 1306 """ 1307 raise NotImplementedError 1308 1309 1310 def set_attribute(self, attribute, value): 1311 attribute_model, get_args = self._get_attribute_model_and_args( 1312 attribute) 1313 attribute_object, _ = attribute_model.objects.get_or_create(**get_args) 1314 attribute_object.value = value 1315 attribute_object.save() 1316 1317 1318 def delete_attribute(self, attribute): 1319 attribute_model, get_args = self._get_attribute_model_and_args( 1320 attribute) 1321 try: 1322 attribute_model.objects.get(**get_args).delete() 1323 except attribute_model.DoesNotExist: 1324 pass 1325 1326 1327 def set_or_delete_attribute(self, attribute, value): 1328 if value is None: 1329 self.delete_attribute(attribute) 1330 else: 1331 self.set_attribute(attribute, value) 1332 1333 1334class ModelWithHashManager(dbmodels.Manager): 1335 """Manager for use with the ModelWithHash abstract model class""" 1336 1337 def create(self, **kwargs): 1338 raise Exception('ModelWithHash manager should use get_or_create() ' 1339 'instead of create()') 1340 1341 1342 def get_or_create(self, **kwargs): 1343 kwargs['the_hash'] = self.model._compute_hash(**kwargs) 1344 return super(ModelWithHashManager, self).get_or_create(**kwargs) 1345 1346 1347class ModelWithHash(dbmodels.Model): 1348 """Superclass with methods for dealing with a hash column""" 1349 1350 the_hash = dbmodels.CharField(max_length=40, unique=True) 1351 1352 objects = ModelWithHashManager() 1353 1354 class Meta: 1355 abstract = True 1356 1357 1358 @classmethod 1359 def _compute_hash(cls, **kwargs): 1360 raise NotImplementedError('Subclasses must override _compute_hash()') 1361 1362 1363 def save(self, force_insert=False, **kwargs): 1364 """Prevents saving the model in most cases 1365 1366 We want these models to be immutable, so the generic save() operation 1367 will not work. These models should be instantiated through their the 1368 model.objects.get_or_create() method instead. 1369 1370 The exception is that save(force_insert=True) will be allowed, since 1371 that creates a new row. However, the preferred way to make instances of 1372 these models is through the get_or_create() method. 1373 """ 1374 if not force_insert: 1375 # Allow a forced insert to happen; if it's a duplicate, the unique 1376 # constraint will catch it later anyways 1377 raise Exception('ModelWithHash is immutable') 1378 super(ModelWithHash, self).save(force_insert=force_insert, **kwargs) 1379