1""" 2Extensions to Django's model logic. 3""" 4 5import django.core.exceptions 6from django.db import connection 7from django.db import connections 8from django.db import models as dbmodels 9from django.db import transaction 10from django.db.models.sql import query 11import django.db.models.sql.where 12 13from autotest_lib.client.common_lib import error 14from autotest_lib.frontend.afe import rdb_model_extensions 15 16 17class ValidationError(django.core.exceptions.ValidationError): 18 """\ 19 Data validation error in adding or updating an object. The associated 20 value is a dictionary mapping field names to error strings. 21 """ 22 23def _quote_name(name): 24 """Shorthand for connection.ops.quote_name().""" 25 return connection.ops.quote_name(name) 26 27 28class LeasedHostManager(dbmodels.Manager): 29 """Query manager for unleased, unlocked hosts. 30 """ 31 def get_query_set(self): 32 return (super(LeasedHostManager, self).get_query_set().filter( 33 leased=0, locked=0)) 34 35 36class ExtendedManager(dbmodels.Manager): 37 """\ 38 Extended manager supporting subquery filtering. 39 """ 40 41 class CustomQuery(query.Query): 42 def __init__(self, *args, **kwargs): 43 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs) 44 self._custom_joins = [] 45 46 47 def clone(self, klass=None, **kwargs): 48 obj = super(ExtendedManager.CustomQuery, self).clone(klass) 49 obj._custom_joins = list(self._custom_joins) 50 return obj 51 52 53 def combine(self, rhs, connector): 54 super(ExtendedManager.CustomQuery, self).combine(rhs, connector) 55 if hasattr(rhs, '_custom_joins'): 56 self._custom_joins.extend(rhs._custom_joins) 57 58 59 def add_custom_join(self, table, condition, join_type, 60 condition_values=(), alias=None): 61 if alias is None: 62 alias = table 63 join_dict = dict(table=table, 64 condition=condition, 65 condition_values=condition_values, 66 join_type=join_type, 67 alias=alias) 68 self._custom_joins.append(join_dict) 69 70 71 @classmethod 72 def convert_query(self, query_set): 73 """ 74 Convert the query set's "query" attribute to a CustomQuery. 75 """ 76 # Make a copy of the query set 77 query_set = query_set.all() 78 query_set.query = query_set.query.clone( 79 klass=ExtendedManager.CustomQuery, 80 _custom_joins=[]) 81 return query_set 82 83 84 class _WhereClause(object): 85 """Object allowing us to inject arbitrary SQL into Django queries. 86 87 By using this instead of extra(where=...), we can still freely combine 88 queries with & and |. 89 """ 90 def __init__(self, clause, values=()): 91 self._clause = clause 92 self._values = values 93 94 95 def as_sql(self, qn=None, connection=None): 96 return self._clause, self._values 97 98 99 def relabel_aliases(self, change_map): 100 return 101 102 103 def add_join(self, query_set, join_table, join_key, join_condition='', 104 join_condition_values=(), join_from_key=None, alias=None, 105 suffix='', exclude=False, force_left_join=False): 106 """Add a join to query_set. 107 108 Join looks like this: 109 (INNER|LEFT) JOIN <join_table> AS <alias> 110 ON (<this table>.<join_from_key> = <join_table>.<join_key> 111 and <join_condition>) 112 113 @param join_table table to join to 114 @param join_key field referencing back to this model to use for the join 115 @param join_condition extra condition for the ON clause of the join 116 @param join_condition_values values to substitute into join_condition 117 @param join_from_key column on this model to join from. 118 @param alias alias to use for for join 119 @param suffix suffix to add to join_table for the join alias, if no 120 alias is provided 121 @param exclude if true, exclude rows that match this join (will use a 122 LEFT OUTER JOIN and an appropriate WHERE condition) 123 @param force_left_join - if true, a LEFT OUTER JOIN will be used 124 instead of an INNER JOIN regardless of other options 125 """ 126 join_from_table = query_set.model._meta.db_table 127 if join_from_key is None: 128 join_from_key = self.model._meta.pk.name 129 if alias is None: 130 alias = join_table + suffix 131 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) 132 full_join_condition = '%s = %s.%s' % (full_join_key, 133 _quote_name(join_from_table), 134 _quote_name(join_from_key)) 135 if join_condition: 136 full_join_condition += ' AND (' + join_condition + ')' 137 if exclude or force_left_join: 138 join_type = query_set.query.LOUTER 139 else: 140 join_type = query_set.query.INNER 141 142 query_set = self.CustomQuery.convert_query(query_set) 143 query_set.query.add_custom_join(join_table, 144 full_join_condition, 145 join_type, 146 condition_values=join_condition_values, 147 alias=alias) 148 149 if exclude: 150 query_set = query_set.extra(where=[full_join_key + ' IS NULL']) 151 152 return query_set 153 154 155 def _info_for_many_to_one_join(self, field, join_to_query, alias): 156 """ 157 @param field: the ForeignKey field on the related model 158 @param join_to_query: the query over the related model that we're 159 joining to 160 @param alias: alias of joined table 161 """ 162 info = {} 163 rhs_table = join_to_query.model._meta.db_table 164 info['rhs_table'] = rhs_table 165 info['rhs_column'] = field.column 166 info['lhs_column'] = field.rel.get_related_field().column 167 rhs_where = join_to_query.query.where 168 rhs_where.relabel_aliases({rhs_table: alias}) 169 compiler = join_to_query.query.get_compiler(using=join_to_query.db) 170 initial_clause, values = compiler.as_sql() 171 # initial_clause is compiled from `join_to_query`, which is a SELECT 172 # query returns at most one record. For it to be used in WHERE clause, 173 # it must be converted to a boolean value using EXISTS. 174 all_clauses = ('EXISTS (%s)' % initial_clause,) 175 if hasattr(join_to_query.query, 'extra_where'): 176 all_clauses += join_to_query.query.extra_where 177 info['where_clause'] = ( 178 ' AND '.join('(%s)' % clause for clause in all_clauses)) 179 info['values'] = values 180 return info 181 182 183 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, 184 m2m_is_on_this_model): 185 """ 186 @param m2m_field: a Django field representing the M2M relationship. 187 It uses a pivot table with the following structure: 188 this model table <---> M2M pivot table <---> joined model table 189 @param join_to_query: the query over the related model that we're 190 joining to. 191 @param alias: alias of joined table 192 """ 193 if m2m_is_on_this_model: 194 # referenced field on this model 195 lhs_id_field = self.model._meta.pk 196 # foreign key on the pivot table referencing lhs_id_field 197 m2m_lhs_column = m2m_field.m2m_column_name() 198 # foreign key on the pivot table referencing rhd_id_field 199 m2m_rhs_column = m2m_field.m2m_reverse_name() 200 # referenced field on related model 201 rhs_id_field = m2m_field.rel.get_related_field() 202 else: 203 lhs_id_field = m2m_field.rel.get_related_field() 204 m2m_lhs_column = m2m_field.m2m_reverse_name() 205 m2m_rhs_column = m2m_field.m2m_column_name() 206 rhs_id_field = join_to_query.model._meta.pk 207 208 info = {} 209 info['rhs_table'] = m2m_field.m2m_db_table() 210 info['rhs_column'] = m2m_lhs_column 211 info['lhs_column'] = lhs_id_field.column 212 213 # select the ID of related models relevant to this join. we can only do 214 # a single join, so we need to gather this information up front and 215 # include it in the join condition. 216 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) 217 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' 218 'match a single related object.') 219 rhs_id = rhs_ids[0] 220 221 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), 222 _quote_name(m2m_rhs_column), 223 rhs_id) 224 info['values'] = () 225 return info 226 227 228 def join_custom_field(self, query_set, join_to_query, alias, 229 left_join=True): 230 """Join to a related model to create a custom field in the given query. 231 232 This method is used to construct a custom field on the given query based 233 on a many-valued relationsip. join_to_query should be a simple query 234 (no joins) on the related model which returns at most one related row 235 per instance of this model. 236 237 For many-to-one relationships, the joined table contains the matching 238 row from the related model it one is related, NULL otherwise. 239 240 For many-to-many relationships, the joined table contains the matching 241 row if it's related, NULL otherwise. 242 """ 243 relationship_type, field = self.determine_relationship( 244 join_to_query.model) 245 246 if relationship_type == self.MANY_TO_ONE: 247 info = self._info_for_many_to_one_join(field, join_to_query, alias) 248 elif relationship_type == self.M2M_ON_RELATED_MODEL: 249 info = self._info_for_many_to_many_join( 250 m2m_field=field, join_to_query=join_to_query, alias=alias, 251 m2m_is_on_this_model=False) 252 elif relationship_type ==self.M2M_ON_THIS_MODEL: 253 info = self._info_for_many_to_many_join( 254 m2m_field=field, join_to_query=join_to_query, alias=alias, 255 m2m_is_on_this_model=True) 256 257 return self.add_join(query_set, info['rhs_table'], info['rhs_column'], 258 join_from_key=info['lhs_column'], 259 join_condition=info['where_clause'], 260 join_condition_values=info['values'], 261 alias=alias, 262 force_left_join=left_join) 263 264 265 def add_where(self, query_set, where, values=()): 266 query_set = query_set.all() 267 query_set.query.where.add(self._WhereClause(where, values), 268 django.db.models.sql.where.AND) 269 return query_set 270 271 272 def _get_quoted_field(self, table, field): 273 return _quote_name(table) + '.' + _quote_name(field) 274 275 276 def get_key_on_this_table(self, key_field=None): 277 if key_field is None: 278 # default to primary key 279 key_field = self.model._meta.pk.column 280 return self._get_quoted_field(self.model._meta.db_table, key_field) 281 282 283 def escape_user_sql(self, sql): 284 return sql.replace('%', '%%') 285 286 287 def _custom_select_query(self, query_set, selects): 288 """Execute a custom select query. 289 290 @param query_set: query set as returned by query_objects. 291 @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id. 292 293 @returns: Result of the query as returned by cursor.fetchall(). 294 """ 295 compiler = query_set.query.get_compiler(using=query_set.db) 296 sql, params = compiler.as_sql() 297 from_ = sql[sql.find(' FROM'):] 298 299 if query_set.query.distinct: 300 distinct = 'DISTINCT ' 301 else: 302 distinct = '' 303 304 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_) 305 # Chose the connection that's responsible for this type of object 306 cursor = connections[query_set.db].cursor() 307 cursor.execute(sql_query, params) 308 return cursor.fetchall() 309 310 311 def _is_relation_to(self, field, model_class): 312 return field.rel and field.rel.to is model_class 313 314 315 MANY_TO_ONE = object() 316 M2M_ON_RELATED_MODEL = object() 317 M2M_ON_THIS_MODEL = object() 318 319 def determine_relationship(self, related_model): 320 """ 321 Determine the relationship between this model and related_model. 322 323 related_model must have some sort of many-valued relationship to this 324 manager's model. 325 @returns (relationship_type, field), where relationship_type is one of 326 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field 327 is the Django field object for the relationship. 328 """ 329 # look for a foreign key field on related_model relating to this model 330 for field in related_model._meta.fields: 331 if self._is_relation_to(field, self.model): 332 return self.MANY_TO_ONE, field 333 334 # look for an M2M field on related_model relating to this model 335 for field in related_model._meta.many_to_many: 336 if self._is_relation_to(field, self.model): 337 return self.M2M_ON_RELATED_MODEL, field 338 339 # maybe this model has the many-to-many field 340 for field in self.model._meta.many_to_many: 341 if self._is_relation_to(field, related_model): 342 return self.M2M_ON_THIS_MODEL, field 343 344 raise ValueError('%s has no relation to %s' % 345 (related_model, self.model)) 346 347 348 def _get_pivot_iterator(self, base_objects_by_id, related_model): 349 """ 350 Determine the relationship between this model and related_model, and 351 return a pivot iterator. 352 @param base_objects_by_id: dict of instances of this model indexed by 353 their IDs 354 @returns a pivot iterator, which yields a tuple (base_object, 355 related_object) for each relationship between a base object and a 356 related object. all base_object instances come from base_objects_by_id. 357 Note -- this depends on Django model internals. 358 """ 359 relationship_type, field = self.determine_relationship(related_model) 360 if relationship_type == self.MANY_TO_ONE: 361 return self._many_to_one_pivot(base_objects_by_id, 362 related_model, field) 363 elif relationship_type == self.M2M_ON_RELATED_MODEL: 364 return self._many_to_many_pivot( 365 base_objects_by_id, related_model, field.m2m_db_table(), 366 field.m2m_reverse_name(), field.m2m_column_name()) 367 else: 368 assert relationship_type == self.M2M_ON_THIS_MODEL 369 return self._many_to_many_pivot( 370 base_objects_by_id, related_model, field.m2m_db_table(), 371 field.m2m_column_name(), field.m2m_reverse_name()) 372 373 374 def _many_to_one_pivot(self, base_objects_by_id, related_model, 375 foreign_key_field): 376 """ 377 @returns a pivot iterator - see _get_pivot_iterator() 378 """ 379 filter_data = {foreign_key_field.name + '__pk__in': 380 base_objects_by_id.keys()} 381 for related_object in related_model.objects.filter(**filter_data): 382 # lookup base object in the dict, rather than grabbing it from the 383 # related object. we need to return instances from the dict, not 384 # fresh instances of the same models (and grabbing model instances 385 # from the related models incurs a DB query each time). 386 base_object_id = getattr(related_object, foreign_key_field.attname) 387 base_object = base_objects_by_id[base_object_id] 388 yield base_object, related_object 389 390 391 def _query_pivot_table(self, base_objects_by_id, pivot_table, 392 pivot_from_field, pivot_to_field, related_model): 393 """ 394 @param id_list list of IDs of self.model objects to include 395 @param pivot_table the name of the pivot table 396 @param pivot_from_field a field name on pivot_table referencing 397 self.model 398 @param pivot_to_field a field name on pivot_table referencing the 399 related model. 400 @param related_model the related model 401 402 @returns pivot list of IDs (base_id, related_id) 403 """ 404 query = """ 405 SELECT %(from_field)s, %(to_field)s 406 FROM %(table)s 407 WHERE %(from_field)s IN (%(id_list)s) 408 """ % dict(from_field=pivot_from_field, 409 to_field=pivot_to_field, 410 table=pivot_table, 411 id_list=','.join(str(id_) for id_ 412 in base_objects_by_id.iterkeys())) 413 414 # Chose the connection that's responsible for this type of object 415 # The databases for related_model and the current model will always 416 # be the same, related_model is just easier to obtain here because 417 # self is only a ExtendedManager, not the object. 418 cursor = connections[related_model.objects.db].cursor() 419 cursor.execute(query) 420 return cursor.fetchall() 421 422 423 def _many_to_many_pivot(self, base_objects_by_id, related_model, 424 pivot_table, pivot_from_field, pivot_to_field): 425 """ 426 @param pivot_table: see _query_pivot_table 427 @param pivot_from_field: see _query_pivot_table 428 @param pivot_to_field: see _query_pivot_table 429 @returns a pivot iterator - see _get_pivot_iterator() 430 """ 431 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table, 432 pivot_from_field, pivot_to_field, 433 related_model) 434 435 all_related_ids = list(set(related_id for base_id, related_id 436 in id_pivot)) 437 related_objects_by_id = related_model.objects.in_bulk(all_related_ids) 438 439 for base_id, related_id in id_pivot: 440 yield base_objects_by_id[base_id], related_objects_by_id[related_id] 441 442 443 def populate_relationships(self, base_objects, related_model, 444 related_list_name): 445 """ 446 For each instance of this model in base_objects, add a field named 447 related_list_name listing all the related objects of type related_model. 448 related_model must be in a many-to-one or many-to-many relationship with 449 this model. 450 @param base_objects - list of instances of this model 451 @param related_model - model class related to this model 452 @param related_list_name - attribute name in which to store the related 453 object list. 454 """ 455 if not base_objects: 456 # if we don't bail early, we'll get a SQL error later 457 return 458 459 base_objects_by_id = dict((base_object._get_pk_val(), base_object) 460 for base_object in base_objects) 461 pivot_iterator = self._get_pivot_iterator(base_objects_by_id, 462 related_model) 463 464 for base_object in base_objects: 465 setattr(base_object, related_list_name, []) 466 467 for base_object, related_object in pivot_iterator: 468 getattr(base_object, related_list_name).append(related_object) 469 470 471class ModelWithInvalidQuerySet(dbmodels.query.QuerySet): 472 """ 473 QuerySet that handles delete() properly for models with an "invalid" bit 474 """ 475 def delete(self): 476 for model in self: 477 model.delete() 478 479 480class ModelWithInvalidManager(ExtendedManager): 481 """ 482 Manager for objects with an "invalid" bit 483 """ 484 def get_query_set(self): 485 return ModelWithInvalidQuerySet(self.model) 486 487 488class ValidObjectsManager(ModelWithInvalidManager): 489 """ 490 Manager returning only objects with invalid=False. 491 """ 492 def get_query_set(self): 493 queryset = super(ValidObjectsManager, self).get_query_set() 494 return queryset.filter(invalid=False) 495 496 497class ModelExtensions(rdb_model_extensions.ModelValidators): 498 """\ 499 Mixin with convenience functions for models, built on top of 500 the model validators in rdb_model_extensions. 501 """ 502 # TODO: at least some of these functions really belong in a custom 503 # Manager class 504 505 506 SERIALIZATION_LINKS_TO_FOLLOW = set() 507 """ 508 To be able to send jobs and hosts to shards, it's necessary to find their 509 dependencies. 510 The most generic approach for this would be to traverse all relationships 511 to other objects recursively. This would list all objects that are related 512 in any way. 513 But this approach finds too many objects: If a host should be transferred, 514 all it's relationships would be traversed. This would find an acl group. 515 If then the acl group's relationships are traversed, the relationship 516 would be followed backwards and many other hosts would be found. 517 518 This mapping tells that algorithm which relations to follow explicitly. 519 """ 520 521 522 SERIALIZATION_LINKS_TO_KEEP = set() 523 """This set stores foreign keys which we don't want to follow, but 524 still want to include in the serialized dictionary. For 525 example, we follow the relationship `Host.hostattribute_set`, 526 but we do not want to follow `HostAttributes.host_id` back to 527 to Host, which would otherwise lead to a circle. However, we still 528 like to serialize HostAttribute.`host_id`.""" 529 530 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set() 531 """ 532 On deserializion, if the object to persist already exists, local fields 533 will only be updated, if their name is in this set. 534 """ 535 536 537 @classmethod 538 def convert_human_readable_values(cls, data, to_human_readable=False): 539 """\ 540 Performs conversions on user-supplied field data, to make it 541 easier for users to pass human-readable data. 542 543 For all fields that have choice sets, convert their values 544 from human-readable strings to enum values, if necessary. This 545 allows users to pass strings instead of the corresponding 546 integer values. 547 548 For all foreign key fields, call smart_get with the supplied 549 data. This allows the user to pass either an ID value or 550 the name of the object as a string. 551 552 If to_human_readable=True, perform the inverse - i.e. convert 553 numeric values to human readable values. 554 555 This method modifies data in-place. 556 """ 557 field_dict = cls.get_field_dict() 558 for field_name in data: 559 if field_name not in field_dict or data[field_name] is None: 560 continue 561 field_obj = field_dict[field_name] 562 # convert enum values 563 if field_obj.choices: 564 for choice_data in field_obj.choices: 565 # choice_data is (value, name) 566 if to_human_readable: 567 from_val, to_val = choice_data 568 else: 569 to_val, from_val = choice_data 570 if from_val == data[field_name]: 571 data[field_name] = to_val 572 break 573 # convert foreign key values 574 elif field_obj.rel: 575 dest_obj = field_obj.rel.to.smart_get(data[field_name], 576 valid_only=False) 577 if to_human_readable: 578 # parameterized_jobs do not have a name_field 579 if (field_name != 'parameterized_job' and 580 dest_obj.name_field is not None): 581 data[field_name] = getattr(dest_obj, 582 dest_obj.name_field) 583 else: 584 data[field_name] = dest_obj 585 586 587 588 589 def _validate_unique(self): 590 """\ 591 Validate that unique fields are unique. Django manipulators do 592 this too, but they're a huge pain to use manually. Trust me. 593 """ 594 errors = {} 595 cls = type(self) 596 field_dict = self.get_field_dict() 597 manager = cls.get_valid_manager() 598 for field_name, field_obj in field_dict.iteritems(): 599 if not field_obj.unique: 600 continue 601 602 value = getattr(self, field_name) 603 if value is None and field_obj.auto_created: 604 # don't bother checking autoincrement fields about to be 605 # generated 606 continue 607 608 existing_objs = manager.filter(**{field_name : value}) 609 num_existing = existing_objs.count() 610 611 if num_existing == 0: 612 continue 613 if num_existing == 1 and existing_objs[0].id == self.id: 614 continue 615 errors[field_name] = ( 616 'This value must be unique (%s)' % (value)) 617 return errors 618 619 620 def _validate(self): 621 """ 622 First coerces all fields on this instance to their proper Python types. 623 Then runs validation on every field. Returns a dictionary of 624 field_name -> error_list. 625 626 Based on validate() from django.db.models.Model in Django 0.96, which 627 was removed in Django 1.0. It should reappear in a later version. See: 628 http://code.djangoproject.com/ticket/6845 629 """ 630 error_dict = {} 631 for f in self._meta.fields: 632 try: 633 python_value = f.to_python( 634 getattr(self, f.attname, f.get_default())) 635 except django.core.exceptions.ValidationError, e: 636 error_dict[f.name] = str(e) 637 continue 638 639 if not f.blank and not python_value: 640 error_dict[f.name] = 'This field is required.' 641 continue 642 643 setattr(self, f.attname, python_value) 644 645 return error_dict 646 647 648 def do_validate(self): 649 errors = self._validate() 650 unique_errors = self._validate_unique() 651 for field_name, error in unique_errors.iteritems(): 652 errors.setdefault(field_name, error) 653 if errors: 654 raise ValidationError(errors) 655 656 657 # actually (externally) useful methods follow 658 659 @classmethod 660 def add_object(cls, data={}, **kwargs): 661 """\ 662 Returns a new object created with the given data (a dictionary 663 mapping field names to values). Merges any extra keyword args 664 into data. 665 """ 666 data = dict(data) 667 data.update(kwargs) 668 data = cls.prepare_data_args(data) 669 cls.convert_human_readable_values(data) 670 data = cls.provide_default_values(data) 671 672 obj = cls(**data) 673 obj.do_validate() 674 obj.save() 675 return obj 676 677 678 def update_object(self, data={}, **kwargs): 679 """\ 680 Updates the object with the given data (a dictionary mapping 681 field names to values). Merges any extra keyword args into 682 data. 683 """ 684 data = dict(data) 685 data.update(kwargs) 686 data = self.prepare_data_args(data) 687 self.convert_human_readable_values(data) 688 for field_name, value in data.iteritems(): 689 setattr(self, field_name, value) 690 self.do_validate() 691 self.save() 692 693 694 # see query_objects() 695 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by', 696 'extra_args', 'extra_where', 'no_distinct') 697 698 699 @classmethod 700 def _extract_special_params(cls, filter_data): 701 """ 702 @returns a tuple of dicts (special_params, regular_filters), where 703 special_params contains the parameters we handle specially and 704 regular_filters is the remaining data to be handled by Django. 705 """ 706 regular_filters = dict(filter_data) 707 special_params = {} 708 for key in cls._SPECIAL_FILTER_KEYS: 709 if key in regular_filters: 710 special_params[key] = regular_filters.pop(key) 711 return special_params, regular_filters 712 713 714 @classmethod 715 def apply_presentation(cls, query, filter_data): 716 """ 717 Apply presentation parameters -- sorting and paging -- to the given 718 query. 719 @returns new query with presentation applied 720 """ 721 special_params, _ = cls._extract_special_params(filter_data) 722 sort_by = special_params.get('sort_by', None) 723 if sort_by: 724 assert isinstance(sort_by, list) or isinstance(sort_by, tuple) 725 query = query.extra(order_by=sort_by) 726 727 query_start = special_params.get('query_start', None) 728 query_limit = special_params.get('query_limit', None) 729 if query_start is not None: 730 if query_limit is None: 731 raise ValueError('Cannot pass query_start without query_limit') 732 # query_limit is passed as a page size 733 query_limit += query_start 734 return query[query_start:query_limit] 735 736 737 @classmethod 738 def query_objects(cls, filter_data, valid_only=True, initial_query=None, 739 apply_presentation=True): 740 """\ 741 Returns a QuerySet object for querying the given model_class 742 with the given filter_data. Optional special arguments in 743 filter_data include: 744 -query_start: index of first return to return 745 -query_limit: maximum number of results to return 746 -sort_by: list of fields to sort on. prefixing a '-' onto a 747 field name changes the sort to descending order. 748 -extra_args: keyword args to pass to query.extra() (see Django 749 DB layer documentation) 750 -extra_where: extra WHERE clause to append 751 -no_distinct: if True, a DISTINCT will not be added to the SELECT 752 """ 753 special_params, regular_filters = cls._extract_special_params( 754 filter_data) 755 756 if initial_query is None: 757 if valid_only: 758 initial_query = cls.get_valid_manager() 759 else: 760 initial_query = cls.objects 761 762 query = initial_query.filter(**regular_filters) 763 764 use_distinct = not special_params.get('no_distinct', False) 765 if use_distinct: 766 query = query.distinct() 767 768 extra_args = special_params.get('extra_args', {}) 769 extra_where = special_params.get('extra_where', None) 770 if extra_where: 771 # escape %'s 772 extra_where = cls.objects.escape_user_sql(extra_where) 773 extra_args.setdefault('where', []).append(extra_where) 774 if extra_args: 775 query = query.extra(**extra_args) 776 # TODO: Use readonly connection for these queries. 777 # This has been disabled, because it's not used anyway, as the 778 # configured readonly user is the same as the real user anyway. 779 780 if apply_presentation: 781 query = cls.apply_presentation(query, filter_data) 782 783 return query 784 785 786 @classmethod 787 def query_count(cls, filter_data, initial_query=None): 788 """\ 789 Like query_objects, but retreive only the count of results. 790 """ 791 filter_data.pop('query_start', None) 792 filter_data.pop('query_limit', None) 793 query = cls.query_objects(filter_data, initial_query=initial_query) 794 return query.count() 795 796 797 @classmethod 798 def clean_object_dicts(cls, field_dicts): 799 """\ 800 Take a list of dicts corresponding to object (as returned by 801 query.values()) and clean the data to be more suitable for 802 returning to the user. 803 """ 804 for field_dict in field_dicts: 805 cls.clean_foreign_keys(field_dict) 806 cls._convert_booleans(field_dict) 807 cls.convert_human_readable_values(field_dict, 808 to_human_readable=True) 809 810 811 @classmethod 812 def list_objects(cls, filter_data, initial_query=None): 813 """\ 814 Like query_objects, but return a list of dictionaries. 815 """ 816 query = cls.query_objects(filter_data, initial_query=initial_query) 817 extra_fields = query.query.extra_select.keys() 818 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields) 819 for model_object in query] 820 return field_dicts 821 822 823 @classmethod 824 def smart_get(cls, id_or_name, valid_only=True): 825 """\ 826 smart_get(integer) -> get object by ID 827 smart_get(string) -> get object by name_field 828 """ 829 if valid_only: 830 manager = cls.get_valid_manager() 831 else: 832 manager = cls.objects 833 834 if isinstance(id_or_name, (int, long)): 835 return manager.get(pk=id_or_name) 836 if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'): 837 return manager.get(**{cls.name_field : id_or_name}) 838 raise ValueError( 839 'Invalid positional argument: %s (%s)' % (id_or_name, 840 type(id_or_name))) 841 842 843 @classmethod 844 def smart_get_bulk(cls, id_or_name_list): 845 invalid_inputs = [] 846 result_objects = [] 847 for id_or_name in id_or_name_list: 848 try: 849 result_objects.append(cls.smart_get(id_or_name)) 850 except cls.DoesNotExist: 851 invalid_inputs.append(id_or_name) 852 if invalid_inputs: 853 raise cls.DoesNotExist('The following %ss do not exist: %s' 854 % (cls.__name__.lower(), 855 ', '.join(invalid_inputs))) 856 return result_objects 857 858 859 def get_object_dict(self, extra_fields=None): 860 """\ 861 Return a dictionary mapping fields to this object's values. @param 862 extra_fields: list of extra attribute names to include, in addition to 863 the fields defined on this object. 864 """ 865 fields = self.get_field_dict().keys() 866 if extra_fields: 867 fields += extra_fields 868 object_dict = dict((field_name, getattr(self, field_name)) 869 for field_name in fields) 870 self.clean_object_dicts([object_dict]) 871 self._postprocess_object_dict(object_dict) 872 return object_dict 873 874 875 def _postprocess_object_dict(self, object_dict): 876 """For subclasses to override.""" 877 pass 878 879 880 @classmethod 881 def get_valid_manager(cls): 882 return cls.objects 883 884 885 def _record_attributes(self, attributes): 886 """ 887 See on_attribute_changed. 888 """ 889 assert not isinstance(attributes, basestring) 890 self._recorded_attributes = dict((attribute, getattr(self, attribute)) 891 for attribute in attributes) 892 893 894 def _check_for_updated_attributes(self): 895 """ 896 See on_attribute_changed. 897 """ 898 for attribute, original_value in self._recorded_attributes.iteritems(): 899 new_value = getattr(self, attribute) 900 if original_value != new_value: 901 self.on_attribute_changed(attribute, original_value) 902 self._record_attributes(self._recorded_attributes.keys()) 903 904 905 def on_attribute_changed(self, attribute, old_value): 906 """ 907 Called whenever an attribute is updated. To be overridden. 908 909 To use this method, you must: 910 * call _record_attributes() from __init__() (after making the super 911 call) with a list of attributes for which you want to be notified upon 912 change. 913 * call _check_for_updated_attributes() from save(). 914 """ 915 pass 916 917 918 def serialize(self, include_dependencies=True): 919 """Serializes the object with dependencies. 920 921 The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies 922 this function will serialize with the object. 923 924 @param include_dependencies: Whether or not to follow relations to 925 objects this object depends on. 926 This parameter is used when uploading 927 jobs from a shard to the master, as the 928 master already has all the dependent 929 objects. 930 931 @returns: Dictionary representation of the object. 932 """ 933 serialized = {} 934 for field in self._meta.concrete_model._meta.local_fields: 935 if field.rel is None: 936 serialized[field.name] = field._get_val_from_obj(self) 937 elif field.name in self.SERIALIZATION_LINKS_TO_KEEP: 938 # attname will contain "_id" suffix for foreign keys, 939 # e.g. HostAttribute.host will be serialized as 'host_id'. 940 # Use it for easy deserialization. 941 serialized[field.attname] = field._get_val_from_obj(self) 942 943 if include_dependencies: 944 for link in self.SERIALIZATION_LINKS_TO_FOLLOW: 945 serialized[link] = self._serialize_relation(link) 946 947 return serialized 948 949 950 def _serialize_relation(self, link): 951 """Serializes dependent objects given the name of the relation. 952 953 @param link: Name of the relation to take objects from. 954 955 @returns For To-Many relationships a list of the serialized related 956 objects, for To-One relationships the serialized related object. 957 """ 958 try: 959 attr = getattr(self, link) 960 except AttributeError: 961 # One-To-One relationships that point to None may raise this 962 return None 963 964 if attr is None: 965 return None 966 if hasattr(attr, 'all'): 967 return [obj.serialize() for obj in attr.all()] 968 return attr.serialize() 969 970 971 @classmethod 972 def _split_local_from_foreign_values(cls, data): 973 """This splits local from foreign values in a serialized object. 974 975 @param data: The serialized object. 976 977 @returns A tuple of two lists, both containing tuples in the form 978 (link_name, link_value). The first list contains all links 979 for local fields, the second one contains those for foreign 980 fields/objects. 981 """ 982 links_to_local_values, links_to_related_values = [], [] 983 for link, value in data.iteritems(): 984 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW: 985 # It's a foreign key 986 links_to_related_values.append((link, value)) 987 else: 988 # It's a local attribute or a foreign key 989 # we don't want to follow. 990 links_to_local_values.append((link, value)) 991 return links_to_local_values, links_to_related_values 992 993 994 @classmethod 995 def _filter_update_allowed_fields(cls, data): 996 """Filters data and returns only files that updates are allowed on. 997 998 This is i.e. needed for syncing aborted bits from the master to shards. 999 1000 Local links are only allowed to be updated, if they are in 1001 SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1002 Overwriting existing values is allowed in order to be able to sync i.e. 1003 the aborted bit from the master to a shard. 1004 1005 The whitelisting mechanism is in place to prevent overwriting local 1006 status: If all fields were overwritten, jobs would be completely be 1007 set back to their original (unstarted) state. 1008 1009 @param data: List with tuples of the form (link_name, link_value), as 1010 returned by _split_local_from_foreign_values. 1011 1012 @returns List of the same format as data, but only containing data for 1013 fields that updates are allowed on. 1014 """ 1015 return [pair for pair in data 1016 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE] 1017 1018 1019 @classmethod 1020 def delete_matching_record(cls, **filter_args): 1021 """Delete records matching the filter. 1022 1023 @param filter_args: Arguments for the django filter 1024 used to locate the record to delete. 1025 """ 1026 try: 1027 existing_record = cls.objects.get(**filter_args) 1028 except cls.DoesNotExist: 1029 return 1030 existing_record.delete() 1031 1032 1033 def _deserialize_local(self, data): 1034 """Set local attributes from a list of tuples. 1035 1036 @param data: List of tuples like returned by 1037 _split_local_from_foreign_values. 1038 """ 1039 if not data: 1040 return 1041 1042 for link, value in data: 1043 setattr(self, link, value) 1044 # Overwridden save() methods are prone to errors, so don't execute them. 1045 # This is because: 1046 # - the overwritten methods depend on ACL groups that don't yet exist 1047 # and don't handle errors 1048 # - the overwritten methods think this object already exists in the db 1049 # because the id is already set 1050 super(type(self), self).save() 1051 1052 1053 def _deserialize_relations(self, data): 1054 """Set foreign attributes from a list of tuples. 1055 1056 This deserialized the related objects using their own deserialize() 1057 function and then sets the relation. 1058 1059 @param data: List of tuples like returned by 1060 _split_local_from_foreign_values. 1061 """ 1062 for link, value in data: 1063 self._deserialize_relation(link, value) 1064 # See comment in _deserialize_local 1065 super(type(self), self).save() 1066 1067 1068 @classmethod 1069 def get_record(cls, data): 1070 """Retrieve a record with the data in the given input arg. 1071 1072 @param data: A dictionary containing the information to use in a query 1073 for data. If child models have different constraints of 1074 uniqueness they should override this model. 1075 1076 @return: An object with matching data. 1077 1078 @raises DoesNotExist: If a record with the given data doesn't exist. 1079 """ 1080 return cls.objects.get(id=data['id']) 1081 1082 1083 @classmethod 1084 def deserialize(cls, data): 1085 """Recursively deserializes and saves an object with it's dependencies. 1086 1087 This takes the result of the serialize method and creates objects 1088 in the database that are just like the original. 1089 1090 If an object of the same type with the same id already exists, it's 1091 local values will be left untouched, unless they are explicitly 1092 whitelisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1093 1094 Deserialize will always recursively propagate to all related objects 1095 present in data though. 1096 I.e. this is necessary to add users to an already existing acl-group. 1097 1098 @param data: Representation of an object and its dependencies, as 1099 returned by serialize. 1100 1101 @returns: The object represented by data if it didn't exist before, 1102 otherwise the object that existed before and has the same type 1103 and id as the one described by data. 1104 """ 1105 if data is None: 1106 return None 1107 1108 local, related = cls._split_local_from_foreign_values(data) 1109 try: 1110 instance = cls.get_record(data) 1111 local = cls._filter_update_allowed_fields(local) 1112 except cls.DoesNotExist: 1113 instance = cls() 1114 1115 instance._deserialize_local(local) 1116 instance._deserialize_relations(related) 1117 1118 return instance 1119 1120 1121 def sanity_check_update_from_shard(self, shard, updated_serialized, 1122 *args, **kwargs): 1123 """Check if an update sent from a shard is legitimate. 1124 1125 @raises error.UnallowedRecordsSentToMaster if an update is not 1126 legitimate. 1127 """ 1128 raise NotImplementedError( 1129 'sanity_check_update_from_shard must be implemented by subclass %s ' 1130 'for type %s' % type(self)) 1131 1132 1133 @transaction.commit_on_success 1134 def update_from_serialized(self, serialized): 1135 """Updates local fields of an existing object from a serialized form. 1136 1137 This is different than the normal deserialize() in the way that it 1138 does update local values, which deserialize doesn't, but doesn't 1139 recursively propagate to related objects, which deserialize() does. 1140 1141 The use case of this function is to update job records on the master 1142 after the jobs have been executed on a slave, as the master is not 1143 interested in updates for users, labels, specialtasks, etc. 1144 1145 @param serialized: Representation of an object and its dependencies, as 1146 returned by serialize. 1147 1148 @raises ValueError: if serialized contains related objects, i.e. not 1149 only local fields. 1150 """ 1151 local, related = ( 1152 self._split_local_from_foreign_values(serialized)) 1153 if related: 1154 raise ValueError('Serialized must not contain foreign ' 1155 'objects: %s' % related) 1156 1157 self._deserialize_local(local) 1158 1159 1160 def custom_deserialize_relation(self, link, data): 1161 """Allows overriding the deserialization behaviour by subclasses.""" 1162 raise NotImplementedError( 1163 'custom_deserialize_relation must be implemented by subclass %s ' 1164 'for relation %s' % (type(self), link)) 1165 1166 1167 def _deserialize_relation(self, link, data): 1168 """Deserializes related objects and sets references on this object. 1169 1170 Relations that point to a list of objects are handled automatically. 1171 For many-to-one or one-to-one relations custom_deserialize_relation 1172 must be overridden by the subclass. 1173 1174 Related objects are deserialized using their deserialize() method. 1175 Thereby they and their dependencies are created if they don't exist 1176 and saved to the database. 1177 1178 @param link: Name of the relation. 1179 @param data: Serialized representation of the related object(s). 1180 This means a list of dictionaries for to-many relations, 1181 just a dictionary for to-one relations. 1182 """ 1183 field = getattr(self, link) 1184 1185 if field and hasattr(field, 'all'): 1186 self._deserialize_2m_relation(link, data, field.model) 1187 else: 1188 self.custom_deserialize_relation(link, data) 1189 1190 1191 def _deserialize_2m_relation(self, link, data, related_class): 1192 """Deserialize related objects for one to-many relationship. 1193 1194 @param link: Name of the relation. 1195 @param data: Serialized representation of the related objects. 1196 This is a list with of dictionaries. 1197 @param related_class: A class representing a django model, with which 1198 this class has a one-to-many relationship. 1199 """ 1200 relation_set = getattr(self, link) 1201 if related_class == self.get_attribute_model(): 1202 # When deserializing a model together with 1203 # its attributes, clear all the exising attributes to ensure 1204 # db consistency. Note 'update' won't be sufficient, as we also 1205 # want to remove any attributes that no longer exist in |data|. 1206 # 1207 # core_filters is a dictionary of filters, defines how 1208 # RelatedMangager would query for the 1-to-many relationship. E.g. 1209 # Host.objects.get( 1210 # id=20).hostattribute_set.core_filters = {host_id:20} 1211 # We use it to delete objects related to the current object. 1212 related_class.objects.filter(**relation_set.core_filters).delete() 1213 for serialized in data: 1214 relation_set.add(related_class.deserialize(serialized)) 1215 1216 1217 @classmethod 1218 def get_attribute_model(cls): 1219 """Return the attribute model. 1220 1221 Subclass with attribute-like model should override this to 1222 return the attribute model class. This method will be 1223 called by _deserialize_2m_relation to determine whether 1224 to clear the one-to-many relations first on deserialization of object. 1225 """ 1226 return None 1227 1228 1229class ModelWithInvalid(ModelExtensions): 1230 """ 1231 Overrides model methods save() and delete() to support invalidation in 1232 place of actual deletion. Subclasses must have a boolean "invalid" 1233 field. 1234 """ 1235 1236 def save(self, *args, **kwargs): 1237 first_time = (self.id is None) 1238 if first_time: 1239 # see if this object was previously added and invalidated 1240 my_name = getattr(self, self.name_field) 1241 filters = {self.name_field : my_name, 'invalid' : True} 1242 try: 1243 old_object = self.__class__.objects.get(**filters) 1244 self.resurrect_object(old_object) 1245 except self.DoesNotExist: 1246 # no existing object 1247 pass 1248 1249 super(ModelWithInvalid, self).save(*args, **kwargs) 1250 1251 1252 def resurrect_object(self, old_object): 1253 """ 1254 Called when self is about to be saved for the first time and is actually 1255 "undeleting" a previously deleted object. Can be overridden by 1256 subclasses to copy data as desired from the deleted entry (but this 1257 superclass implementation must normally be called). 1258 """ 1259 self.id = old_object.id 1260 1261 1262 def clean_object(self): 1263 """ 1264 This method is called when an object is marked invalid. 1265 Subclasses should override this to clean up relationships that 1266 should no longer exist if the object were deleted. 1267 """ 1268 pass 1269 1270 1271 def delete(self): 1272 self.invalid = self.invalid 1273 assert not self.invalid 1274 self.invalid = True 1275 self.save() 1276 self.clean_object() 1277 1278 1279 @classmethod 1280 def get_valid_manager(cls): 1281 return cls.valid_objects 1282 1283 1284 class Manipulator(object): 1285 """ 1286 Force default manipulators to look only at valid objects - 1287 otherwise they will match against invalid objects when checking 1288 uniqueness. 1289 """ 1290 @classmethod 1291 def _prepare(cls, model): 1292 super(ModelWithInvalid.Manipulator, cls)._prepare(model) 1293 cls.manager = model.valid_objects 1294 1295 1296class ModelWithAttributes(object): 1297 """ 1298 Mixin class for models that have an attribute model associated with them. 1299 The attribute model is assumed to have its value field named "value". 1300 """ 1301 1302 def _get_attribute_model_and_args(self, attribute): 1303 """ 1304 Subclasses should override this to return a tuple (attribute_model, 1305 keyword_args), where attribute_model is a model class and keyword_args 1306 is a dict of args to pass to attribute_model.objects.get() to get an 1307 instance of the given attribute on this object. 1308 """ 1309 raise NotImplementedError 1310 1311 1312 def _is_replaced_by_static_attribute(self, attribute): 1313 """ 1314 Subclasses could override this to indicate whether it has static 1315 attributes. 1316 """ 1317 return False 1318 1319 1320 def set_attribute(self, attribute, value): 1321 if self._is_replaced_by_static_attribute(attribute): 1322 raise error.UnmodifiableAttributeException( 1323 'Failed to set attribute "%s" for host "%s" since it ' 1324 'is static. Use go/chromeos-skylab-inventory-tools to ' 1325 'modify this attribute.' % (attribute, self.hostname)) 1326 1327 attribute_model, get_args = self._get_attribute_model_and_args( 1328 attribute) 1329 attribute_object, _ = attribute_model.objects.get_or_create(**get_args) 1330 attribute_object.value = value 1331 attribute_object.save() 1332 1333 1334 def delete_attribute(self, attribute): 1335 if self._is_replaced_by_static_attribute(attribute): 1336 raise error.UnmodifiableAttributeException( 1337 'Failed to delete attribute "%s" for host "%s" since it ' 1338 'is static. Use go/chromeos-skylab-inventory-tools to ' 1339 'modify this attribute.' % (attribute, self.hostname)) 1340 1341 attribute_model, get_args = self._get_attribute_model_and_args( 1342 attribute) 1343 try: 1344 attribute_model.objects.get(**get_args).delete() 1345 except attribute_model.DoesNotExist: 1346 pass 1347 1348 1349 def set_or_delete_attribute(self, attribute, value): 1350 if value is None: 1351 self.delete_attribute(attribute) 1352 else: 1353 self.set_attribute(attribute, value) 1354 1355 1356class ModelWithHashManager(dbmodels.Manager): 1357 """Manager for use with the ModelWithHash abstract model class""" 1358 1359 def create(self, **kwargs): 1360 raise Exception('ModelWithHash manager should use get_or_create() ' 1361 'instead of create()') 1362 1363 1364 def get_or_create(self, **kwargs): 1365 kwargs['the_hash'] = self.model._compute_hash(**kwargs) 1366 return super(ModelWithHashManager, self).get_or_create(**kwargs) 1367 1368 1369class ModelWithHash(dbmodels.Model): 1370 """Superclass with methods for dealing with a hash column""" 1371 1372 the_hash = dbmodels.CharField(max_length=40, unique=True) 1373 1374 objects = ModelWithHashManager() 1375 1376 class Meta: 1377 abstract = True 1378 1379 1380 @classmethod 1381 def _compute_hash(cls, **kwargs): 1382 raise NotImplementedError('Subclasses must override _compute_hash()') 1383 1384 1385 def save(self, force_insert=False, **kwargs): 1386 """Prevents saving the model in most cases 1387 1388 We want these models to be immutable, so the generic save() operation 1389 will not work. These models should be instantiated through their the 1390 model.objects.get_or_create() method instead. 1391 1392 The exception is that save(force_insert=True) will be allowed, since 1393 that creates a new row. However, the preferred way to make instances of 1394 these models is through the get_or_create() method. 1395 """ 1396 if not force_insert: 1397 # Allow a forced insert to happen; if it's a duplicate, the unique 1398 # constraint will catch it later anyways 1399 raise Exception('ModelWithHash is immutable') 1400 super(ModelWithHash, self).save(force_insert=force_insert, **kwargs) 1401