1""" 2Extensions to Django's model logic. 3""" 4 5import django.core.exceptions 6from django.db import connection 7from django.db import connections 8from django.db import models as dbmodels 9from django.db import transaction 10from django.db.models.sql import query 11import django.db.models.sql.where 12 13from autotest_lib.client.common_lib import error 14from autotest_lib.frontend.afe import rdb_model_extensions 15 16 17class ValidationError(django.core.exceptions.ValidationError): 18 """\ 19 Data validation error in adding or updating an object. The associated 20 value is a dictionary mapping field names to error strings. 21 """ 22 23def _quote_name(name): 24 """Shorthand for connection.ops.quote_name().""" 25 return connection.ops.quote_name(name) 26 27 28class LeasedHostManager(dbmodels.Manager): 29 """Query manager for unleased, unlocked hosts. 30 """ 31 def get_query_set(self): 32 return (super(LeasedHostManager, self).get_query_set().filter( 33 leased=0, locked=0)) 34 35 36class ExtendedManager(dbmodels.Manager): 37 """\ 38 Extended manager supporting subquery filtering. 39 """ 40 41 class CustomQuery(query.Query): 42 def __init__(self, *args, **kwargs): 43 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs) 44 self._custom_joins = [] 45 46 47 def clone(self, klass=None, **kwargs): 48 obj = super(ExtendedManager.CustomQuery, self).clone(klass) 49 obj._custom_joins = list(self._custom_joins) 50 return obj 51 52 53 def combine(self, rhs, connector): 54 super(ExtendedManager.CustomQuery, self).combine(rhs, connector) 55 if hasattr(rhs, '_custom_joins'): 56 self._custom_joins.extend(rhs._custom_joins) 57 58 59 def add_custom_join(self, table, condition, join_type, 60 condition_values=(), alias=None): 61 if alias is None: 62 alias = table 63 join_dict = dict(table=table, 64 condition=condition, 65 condition_values=condition_values, 66 join_type=join_type, 67 alias=alias) 68 self._custom_joins.append(join_dict) 69 70 71 @classmethod 72 def convert_query(self, query_set): 73 """ 74 Convert the query set's "query" attribute to a CustomQuery. 75 """ 76 # Make a copy of the query set 77 query_set = query_set.all() 78 query_set.query = query_set.query.clone( 79 klass=ExtendedManager.CustomQuery, 80 _custom_joins=[]) 81 return query_set 82 83 84 class _WhereClause(object): 85 """Object allowing us to inject arbitrary SQL into Django queries. 86 87 By using this instead of extra(where=...), we can still freely combine 88 queries with & and |. 89 """ 90 def __init__(self, clause, values=()): 91 self._clause = clause 92 self._values = values 93 94 95 def as_sql(self, qn=None, connection=None): 96 return self._clause, self._values 97 98 99 def relabel_aliases(self, change_map): 100 return 101 102 103 def add_join(self, query_set, join_table, join_key, join_condition='', 104 join_condition_values=(), join_from_key=None, alias=None, 105 suffix='', exclude=False, force_left_join=False): 106 """Add a join to query_set. 107 108 Join looks like this: 109 (INNER|LEFT) JOIN <join_table> AS <alias> 110 ON (<this table>.<join_from_key> = <join_table>.<join_key> 111 and <join_condition>) 112 113 @param join_table table to join to 114 @param join_key field referencing back to this model to use for the join 115 @param join_condition extra condition for the ON clause of the join 116 @param join_condition_values values to substitute into join_condition 117 @param join_from_key column on this model to join from. 118 @param alias alias to use for for join 119 @param suffix suffix to add to join_table for the join alias, if no 120 alias is provided 121 @param exclude if true, exclude rows that match this join (will use a 122 LEFT OUTER JOIN and an appropriate WHERE condition) 123 @param force_left_join - if true, a LEFT OUTER JOIN will be used 124 instead of an INNER JOIN regardless of other options 125 """ 126 join_from_table = query_set.model._meta.db_table 127 if join_from_key is None: 128 join_from_key = self.model._meta.pk.name 129 if alias is None: 130 alias = join_table + suffix 131 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) 132 full_join_condition = '%s = %s.%s' % (full_join_key, 133 _quote_name(join_from_table), 134 _quote_name(join_from_key)) 135 if join_condition: 136 full_join_condition += ' AND (' + join_condition + ')' 137 if exclude or force_left_join: 138 join_type = query_set.query.LOUTER 139 else: 140 join_type = query_set.query.INNER 141 142 query_set = self.CustomQuery.convert_query(query_set) 143 query_set.query.add_custom_join(join_table, 144 full_join_condition, 145 join_type, 146 condition_values=join_condition_values, 147 alias=alias) 148 149 if exclude: 150 query_set = query_set.extra(where=[full_join_key + ' IS NULL']) 151 152 return query_set 153 154 155 def _info_for_many_to_one_join(self, field, join_to_query, alias): 156 """ 157 @param field: the ForeignKey field on the related model 158 @param join_to_query: the query over the related model that we're 159 joining to 160 @param alias: alias of joined table 161 """ 162 info = {} 163 rhs_table = join_to_query.model._meta.db_table 164 info['rhs_table'] = rhs_table 165 info['rhs_column'] = field.column 166 info['lhs_column'] = field.rel.get_related_field().column 167 rhs_where = join_to_query.query.where 168 rhs_where.relabel_aliases({rhs_table: alias}) 169 compiler = join_to_query.query.get_compiler(using=join_to_query.db) 170 initial_clause, values = compiler.as_sql() 171 # initial_clause is compiled from `join_to_query`, which is a SELECT 172 # query returns at most one record. For it to be used in WHERE clause, 173 # it must be converted to a boolean value using EXISTS. 174 all_clauses = ('EXISTS (%s)' % initial_clause,) 175 if hasattr(join_to_query.query, 'extra_where'): 176 all_clauses += join_to_query.query.extra_where 177 info['where_clause'] = ( 178 ' AND '.join('(%s)' % clause for clause in all_clauses)) 179 info['values'] = values 180 return info 181 182 183 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, 184 m2m_is_on_this_model): 185 """ 186 @param m2m_field: a Django field representing the M2M relationship. 187 It uses a pivot table with the following structure: 188 this model table <---> M2M pivot table <---> joined model table 189 @param join_to_query: the query over the related model that we're 190 joining to. 191 @param alias: alias of joined table 192 """ 193 if m2m_is_on_this_model: 194 # referenced field on this model 195 lhs_id_field = self.model._meta.pk 196 # foreign key on the pivot table referencing lhs_id_field 197 m2m_lhs_column = m2m_field.m2m_column_name() 198 # foreign key on the pivot table referencing rhd_id_field 199 m2m_rhs_column = m2m_field.m2m_reverse_name() 200 # referenced field on related model 201 rhs_id_field = m2m_field.rel.get_related_field() 202 else: 203 lhs_id_field = m2m_field.rel.get_related_field() 204 m2m_lhs_column = m2m_field.m2m_reverse_name() 205 m2m_rhs_column = m2m_field.m2m_column_name() 206 rhs_id_field = join_to_query.model._meta.pk 207 208 info = {} 209 info['rhs_table'] = m2m_field.m2m_db_table() 210 info['rhs_column'] = m2m_lhs_column 211 info['lhs_column'] = lhs_id_field.column 212 213 # select the ID of related models relevant to this join. we can only do 214 # a single join, so we need to gather this information up front and 215 # include it in the join condition. 216 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) 217 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' 218 'match a single related object.') 219 rhs_id = rhs_ids[0] 220 221 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), 222 _quote_name(m2m_rhs_column), 223 rhs_id) 224 info['values'] = () 225 return info 226 227 228 def join_custom_field(self, query_set, join_to_query, alias, 229 left_join=True): 230 """Join to a related model to create a custom field in the given query. 231 232 This method is used to construct a custom field on the given query based 233 on a many-valued relationsip. join_to_query should be a simple query 234 (no joins) on the related model which returns at most one related row 235 per instance of this model. 236 237 For many-to-one relationships, the joined table contains the matching 238 row from the related model it one is related, NULL otherwise. 239 240 For many-to-many relationships, the joined table contains the matching 241 row if it's related, NULL otherwise. 242 """ 243 relationship_type, field = self.determine_relationship( 244 join_to_query.model) 245 246 if relationship_type == self.MANY_TO_ONE: 247 info = self._info_for_many_to_one_join(field, join_to_query, alias) 248 elif relationship_type == self.M2M_ON_RELATED_MODEL: 249 info = self._info_for_many_to_many_join( 250 m2m_field=field, join_to_query=join_to_query, alias=alias, 251 m2m_is_on_this_model=False) 252 elif relationship_type ==self.M2M_ON_THIS_MODEL: 253 info = self._info_for_many_to_many_join( 254 m2m_field=field, join_to_query=join_to_query, alias=alias, 255 m2m_is_on_this_model=True) 256 257 return self.add_join(query_set, info['rhs_table'], info['rhs_column'], 258 join_from_key=info['lhs_column'], 259 join_condition=info['where_clause'], 260 join_condition_values=info['values'], 261 alias=alias, 262 force_left_join=left_join) 263 264 265 def add_where(self, query_set, where, values=()): 266 query_set = query_set.all() 267 query_set.query.where.add(self._WhereClause(where, values), 268 django.db.models.sql.where.AND) 269 return query_set 270 271 272 def _get_quoted_field(self, table, field): 273 return _quote_name(table) + '.' + _quote_name(field) 274 275 276 def get_key_on_this_table(self, key_field=None): 277 if key_field is None: 278 # default to primary key 279 key_field = self.model._meta.pk.column 280 return self._get_quoted_field(self.model._meta.db_table, key_field) 281 282 283 def escape_user_sql(self, sql): 284 return sql.replace('%', '%%') 285 286 287 def _custom_select_query(self, query_set, selects): 288 """Execute a custom select query. 289 290 @param query_set: query set as returned by query_objects. 291 @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id. 292 293 @returns: Result of the query as returned by cursor.fetchall(). 294 """ 295 compiler = query_set.query.get_compiler(using=query_set.db) 296 sql, params = compiler.as_sql() 297 from_ = sql[sql.find(' FROM'):] 298 299 if query_set.query.distinct: 300 distinct = 'DISTINCT ' 301 else: 302 distinct = '' 303 304 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_) 305 # Chose the connection that's responsible for this type of object 306 cursor = connections[query_set.db].cursor() 307 cursor.execute(sql_query, params) 308 return cursor.fetchall() 309 310 311 def _is_relation_to(self, field, model_class): 312 return field.rel and field.rel.to is model_class 313 314 315 MANY_TO_ONE = object() 316 M2M_ON_RELATED_MODEL = object() 317 M2M_ON_THIS_MODEL = object() 318 319 def determine_relationship(self, related_model): 320 """ 321 Determine the relationship between this model and related_model. 322 323 related_model must have some sort of many-valued relationship to this 324 manager's model. 325 @returns (relationship_type, field), where relationship_type is one of 326 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field 327 is the Django field object for the relationship. 328 """ 329 # look for a foreign key field on related_model relating to this model 330 for field in related_model._meta.fields: 331 if self._is_relation_to(field, self.model): 332 return self.MANY_TO_ONE, field 333 334 # look for an M2M field on related_model relating to this model 335 for field in related_model._meta.many_to_many: 336 if self._is_relation_to(field, self.model): 337 return self.M2M_ON_RELATED_MODEL, field 338 339 # maybe this model has the many-to-many field 340 for field in self.model._meta.many_to_many: 341 if self._is_relation_to(field, related_model): 342 return self.M2M_ON_THIS_MODEL, field 343 344 raise ValueError('%s has no relation to %s' % 345 (related_model, self.model)) 346 347 348 def _get_pivot_iterator(self, base_objects_by_id, related_model): 349 """ 350 Determine the relationship between this model and related_model, and 351 return a pivot iterator. 352 @param base_objects_by_id: dict of instances of this model indexed by 353 their IDs 354 @returns a pivot iterator, which yields a tuple (base_object, 355 related_object) for each relationship between a base object and a 356 related object. all base_object instances come from base_objects_by_id. 357 Note -- this depends on Django model internals. 358 """ 359 relationship_type, field = self.determine_relationship(related_model) 360 if relationship_type == self.MANY_TO_ONE: 361 return self._many_to_one_pivot(base_objects_by_id, 362 related_model, field) 363 elif relationship_type == self.M2M_ON_RELATED_MODEL: 364 return self._many_to_many_pivot( 365 base_objects_by_id, related_model, field.m2m_db_table(), 366 field.m2m_reverse_name(), field.m2m_column_name()) 367 else: 368 assert relationship_type == self.M2M_ON_THIS_MODEL 369 return self._many_to_many_pivot( 370 base_objects_by_id, related_model, field.m2m_db_table(), 371 field.m2m_column_name(), field.m2m_reverse_name()) 372 373 374 def _many_to_one_pivot(self, base_objects_by_id, related_model, 375 foreign_key_field): 376 """ 377 @returns a pivot iterator - see _get_pivot_iterator() 378 """ 379 filter_data = {foreign_key_field.name + '__pk__in': 380 base_objects_by_id.keys()} 381 for related_object in related_model.objects.filter(**filter_data): 382 # lookup base object in the dict, rather than grabbing it from the 383 # related object. we need to return instances from the dict, not 384 # fresh instances of the same models (and grabbing model instances 385 # from the related models incurs a DB query each time). 386 base_object_id = getattr(related_object, foreign_key_field.attname) 387 base_object = base_objects_by_id[base_object_id] 388 yield base_object, related_object 389 390 391 def _query_pivot_table(self, base_objects_by_id, pivot_table, 392 pivot_from_field, pivot_to_field, related_model): 393 """ 394 @param id_list list of IDs of self.model objects to include 395 @param pivot_table the name of the pivot table 396 @param pivot_from_field a field name on pivot_table referencing 397 self.model 398 @param pivot_to_field a field name on pivot_table referencing the 399 related model. 400 @param related_model the related model 401 402 @returns pivot list of IDs (base_id, related_id) 403 """ 404 query = """ 405 SELECT %(from_field)s, %(to_field)s 406 FROM %(table)s 407 WHERE %(from_field)s IN (%(id_list)s) 408 """ % dict(from_field=pivot_from_field, 409 to_field=pivot_to_field, 410 table=pivot_table, 411 id_list=','.join(str(id_) for id_ 412 in base_objects_by_id.iterkeys())) 413 414 # Chose the connection that's responsible for this type of object 415 # The databases for related_model and the current model will always 416 # be the same, related_model is just easier to obtain here because 417 # self is only a ExtendedManager, not the object. 418 cursor = connections[related_model.objects.db].cursor() 419 cursor.execute(query) 420 return cursor.fetchall() 421 422 423 def _many_to_many_pivot(self, base_objects_by_id, related_model, 424 pivot_table, pivot_from_field, pivot_to_field): 425 """ 426 @param pivot_table: see _query_pivot_table 427 @param pivot_from_field: see _query_pivot_table 428 @param pivot_to_field: see _query_pivot_table 429 @returns a pivot iterator - see _get_pivot_iterator() 430 """ 431 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table, 432 pivot_from_field, pivot_to_field, 433 related_model) 434 435 all_related_ids = list(set(related_id for base_id, related_id 436 in id_pivot)) 437 related_objects_by_id = related_model.objects.in_bulk(all_related_ids) 438 439 for base_id, related_id in id_pivot: 440 yield base_objects_by_id[base_id], related_objects_by_id[related_id] 441 442 443 def populate_relationships(self, base_objects, related_model, 444 related_list_name): 445 """ 446 For each instance of this model in base_objects, add a field named 447 related_list_name listing all the related objects of type related_model. 448 related_model must be in a many-to-one or many-to-many relationship with 449 this model. 450 @param base_objects - list of instances of this model 451 @param related_model - model class related to this model 452 @param related_list_name - attribute name in which to store the related 453 object list. 454 """ 455 if not base_objects: 456 # if we don't bail early, we'll get a SQL error later 457 return 458 459 # The default maximum value of a host parameter number in SQLite is 999. 460 # Exceed this will get a DatabaseError later. 461 batch_size = 900 462 for i in xrange(0, len(base_objects), batch_size): 463 base_objects_batch = base_objects[i:i + batch_size] 464 base_objects_by_id = dict((base_object._get_pk_val(), base_object) 465 for base_object in base_objects_batch) 466 pivot_iterator = self._get_pivot_iterator(base_objects_by_id, 467 related_model) 468 469 for base_object in base_objects_batch: 470 setattr(base_object, related_list_name, []) 471 472 for base_object, related_object in pivot_iterator: 473 getattr(base_object, related_list_name).append(related_object) 474 475 476class ModelWithInvalidQuerySet(dbmodels.query.QuerySet): 477 """ 478 QuerySet that handles delete() properly for models with an "invalid" bit 479 """ 480 def delete(self): 481 for model in self: 482 model.delete() 483 484 485class ModelWithInvalidManager(ExtendedManager): 486 """ 487 Manager for objects with an "invalid" bit 488 """ 489 def get_query_set(self): 490 return ModelWithInvalidQuerySet(self.model) 491 492 493class ValidObjectsManager(ModelWithInvalidManager): 494 """ 495 Manager returning only objects with invalid=False. 496 """ 497 def get_query_set(self): 498 queryset = super(ValidObjectsManager, self).get_query_set() 499 return queryset.filter(invalid=False) 500 501 502class ModelExtensions(rdb_model_extensions.ModelValidators): 503 """\ 504 Mixin with convenience functions for models, built on top of 505 the model validators in rdb_model_extensions. 506 """ 507 # TODO: at least some of these functions really belong in a custom 508 # Manager class 509 510 511 SERIALIZATION_LINKS_TO_FOLLOW = set() 512 """ 513 To be able to send jobs and hosts to shards, it's necessary to find their 514 dependencies. 515 The most generic approach for this would be to traverse all relationships 516 to other objects recursively. This would list all objects that are related 517 in any way. 518 But this approach finds too many objects: If a host should be transferred, 519 all it's relationships would be traversed. This would find an acl group. 520 If then the acl group's relationships are traversed, the relationship 521 would be followed backwards and many other hosts would be found. 522 523 This mapping tells that algorithm which relations to follow explicitly. 524 """ 525 526 527 SERIALIZATION_LINKS_TO_KEEP = set() 528 """This set stores foreign keys which we don't want to follow, but 529 still want to include in the serialized dictionary. For 530 example, we follow the relationship `Host.hostattribute_set`, 531 but we do not want to follow `HostAttributes.host_id` back to 532 to Host, which would otherwise lead to a circle. However, we still 533 like to serialize HostAttribute.`host_id`.""" 534 535 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set() 536 """ 537 On deserializion, if the object to persist already exists, local fields 538 will only be updated, if their name is in this set. 539 """ 540 541 542 @classmethod 543 def convert_human_readable_values(cls, data, to_human_readable=False): 544 """\ 545 Performs conversions on user-supplied field data, to make it 546 easier for users to pass human-readable data. 547 548 For all fields that have choice sets, convert their values 549 from human-readable strings to enum values, if necessary. This 550 allows users to pass strings instead of the corresponding 551 integer values. 552 553 For all foreign key fields, call smart_get with the supplied 554 data. This allows the user to pass either an ID value or 555 the name of the object as a string. 556 557 If to_human_readable=True, perform the inverse - i.e. convert 558 numeric values to human readable values. 559 560 This method modifies data in-place. 561 """ 562 field_dict = cls.get_field_dict() 563 for field_name in data: 564 if field_name not in field_dict or data[field_name] is None: 565 continue 566 field_obj = field_dict[field_name] 567 # convert enum values 568 if field_obj.choices: 569 for choice_data in field_obj.choices: 570 # choice_data is (value, name) 571 if to_human_readable: 572 from_val, to_val = choice_data 573 else: 574 to_val, from_val = choice_data 575 if from_val == data[field_name]: 576 data[field_name] = to_val 577 break 578 # convert foreign key values 579 elif field_obj.rel: 580 dest_obj = field_obj.rel.to.smart_get(data[field_name], 581 valid_only=False) 582 if to_human_readable: 583 # parameterized_jobs do not have a name_field 584 if (field_name != 'parameterized_job' and 585 dest_obj.name_field is not None): 586 data[field_name] = getattr(dest_obj, 587 dest_obj.name_field) 588 else: 589 data[field_name] = dest_obj 590 591 592 593 594 def _validate_unique(self): 595 """\ 596 Validate that unique fields are unique. Django manipulators do 597 this too, but they're a huge pain to use manually. Trust me. 598 """ 599 errors = {} 600 cls = type(self) 601 field_dict = self.get_field_dict() 602 manager = cls.get_valid_manager() 603 for field_name, field_obj in field_dict.iteritems(): 604 if not field_obj.unique: 605 continue 606 607 value = getattr(self, field_name) 608 if value is None and field_obj.auto_created: 609 # don't bother checking autoincrement fields about to be 610 # generated 611 continue 612 613 existing_objs = manager.filter(**{field_name : value}) 614 num_existing = existing_objs.count() 615 616 if num_existing == 0: 617 continue 618 if num_existing == 1 and existing_objs[0].id == self.id: 619 continue 620 errors[field_name] = ( 621 'This value must be unique (%s)' % (value)) 622 return errors 623 624 625 def _validate(self): 626 """ 627 First coerces all fields on this instance to their proper Python types. 628 Then runs validation on every field. Returns a dictionary of 629 field_name -> error_list. 630 631 Based on validate() from django.db.models.Model in Django 0.96, which 632 was removed in Django 1.0. It should reappear in a later version. See: 633 http://code.djangoproject.com/ticket/6845 634 """ 635 error_dict = {} 636 for f in self._meta.fields: 637 try: 638 python_value = f.to_python( 639 getattr(self, f.attname, f.get_default())) 640 except django.core.exceptions.ValidationError, e: 641 error_dict[f.name] = str(e) 642 continue 643 644 if not f.blank and not python_value: 645 error_dict[f.name] = 'This field is required.' 646 continue 647 648 setattr(self, f.attname, python_value) 649 650 return error_dict 651 652 653 def do_validate(self): 654 errors = self._validate() 655 unique_errors = self._validate_unique() 656 for field_name, error in unique_errors.iteritems(): 657 errors.setdefault(field_name, error) 658 if errors: 659 raise ValidationError(errors) 660 661 662 # actually (externally) useful methods follow 663 664 @classmethod 665 def add_object(cls, data={}, **kwargs): 666 """\ 667 Returns a new object created with the given data (a dictionary 668 mapping field names to values). Merges any extra keyword args 669 into data. 670 """ 671 data = dict(data) 672 data.update(kwargs) 673 data = cls.prepare_data_args(data) 674 cls.convert_human_readable_values(data) 675 data = cls.provide_default_values(data) 676 677 obj = cls(**data) 678 obj.do_validate() 679 obj.save() 680 return obj 681 682 683 def update_object(self, data={}, **kwargs): 684 """\ 685 Updates the object with the given data (a dictionary mapping 686 field names to values). Merges any extra keyword args into 687 data. 688 """ 689 data = dict(data) 690 data.update(kwargs) 691 data = self.prepare_data_args(data) 692 self.convert_human_readable_values(data) 693 for field_name, value in data.iteritems(): 694 setattr(self, field_name, value) 695 self.do_validate() 696 self.save() 697 698 699 # see query_objects() 700 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by', 701 'extra_args', 'extra_where', 'no_distinct') 702 703 704 @classmethod 705 def _extract_special_params(cls, filter_data): 706 """ 707 @returns a tuple of dicts (special_params, regular_filters), where 708 special_params contains the parameters we handle specially and 709 regular_filters is the remaining data to be handled by Django. 710 """ 711 regular_filters = dict(filter_data) 712 special_params = {} 713 for key in cls._SPECIAL_FILTER_KEYS: 714 if key in regular_filters: 715 special_params[key] = regular_filters.pop(key) 716 return special_params, regular_filters 717 718 719 @classmethod 720 def apply_presentation(cls, query, filter_data): 721 """ 722 Apply presentation parameters -- sorting and paging -- to the given 723 query. 724 @returns new query with presentation applied 725 """ 726 special_params, _ = cls._extract_special_params(filter_data) 727 sort_by = special_params.get('sort_by', None) 728 if sort_by: 729 assert isinstance(sort_by, list) or isinstance(sort_by, tuple) 730 query = query.extra(order_by=sort_by) 731 732 query_start = special_params.get('query_start', None) 733 query_limit = special_params.get('query_limit', None) 734 if query_start is not None: 735 if query_limit is None: 736 raise ValueError('Cannot pass query_start without query_limit') 737 # query_limit is passed as a page size 738 query_limit += query_start 739 return query[query_start:query_limit] 740 741 742 @classmethod 743 def query_objects(cls, filter_data, valid_only=True, initial_query=None, 744 apply_presentation=True): 745 """\ 746 Returns a QuerySet object for querying the given model_class 747 with the given filter_data. Optional special arguments in 748 filter_data include: 749 -query_start: index of first return to return 750 -query_limit: maximum number of results to return 751 -sort_by: list of fields to sort on. prefixing a '-' onto a 752 field name changes the sort to descending order. 753 -extra_args: keyword args to pass to query.extra() (see Django 754 DB layer documentation) 755 -extra_where: extra WHERE clause to append 756 -no_distinct: if True, a DISTINCT will not be added to the SELECT 757 """ 758 special_params, regular_filters = cls._extract_special_params( 759 filter_data) 760 761 if initial_query is None: 762 if valid_only: 763 initial_query = cls.get_valid_manager() 764 else: 765 initial_query = cls.objects 766 767 query = initial_query.filter(**regular_filters) 768 769 use_distinct = not special_params.get('no_distinct', False) 770 if use_distinct: 771 query = query.distinct() 772 773 extra_args = special_params.get('extra_args', {}) 774 extra_where = special_params.get('extra_where', None) 775 if extra_where: 776 # escape %'s 777 extra_where = cls.objects.escape_user_sql(extra_where) 778 extra_args.setdefault('where', []).append(extra_where) 779 if extra_args: 780 query = query.extra(**extra_args) 781 # TODO: Use readonly connection for these queries. 782 # This has been disabled, because it's not used anyway, as the 783 # configured readonly user is the same as the real user anyway. 784 785 if apply_presentation: 786 query = cls.apply_presentation(query, filter_data) 787 788 return query 789 790 791 @classmethod 792 def query_count(cls, filter_data, initial_query=None): 793 """\ 794 Like query_objects, but retreive only the count of results. 795 """ 796 filter_data.pop('query_start', None) 797 filter_data.pop('query_limit', None) 798 query = cls.query_objects(filter_data, initial_query=initial_query) 799 return query.count() 800 801 802 @classmethod 803 def clean_object_dicts(cls, field_dicts): 804 """\ 805 Take a list of dicts corresponding to object (as returned by 806 query.values()) and clean the data to be more suitable for 807 returning to the user. 808 """ 809 for field_dict in field_dicts: 810 cls.clean_foreign_keys(field_dict) 811 cls._convert_booleans(field_dict) 812 cls.convert_human_readable_values(field_dict, 813 to_human_readable=True) 814 815 816 @classmethod 817 def list_objects(cls, filter_data, initial_query=None): 818 """\ 819 Like query_objects, but return a list of dictionaries. 820 """ 821 query = cls.query_objects(filter_data, initial_query=initial_query) 822 extra_fields = query.query.extra_select.keys() 823 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields) 824 for model_object in query] 825 return field_dicts 826 827 828 @classmethod 829 def smart_get(cls, id_or_name, valid_only=True): 830 """\ 831 smart_get(integer) -> get object by ID 832 smart_get(string) -> get object by name_field 833 """ 834 if valid_only: 835 manager = cls.get_valid_manager() 836 else: 837 manager = cls.objects 838 839 if isinstance(id_or_name, (int, long)): 840 return manager.get(pk=id_or_name) 841 if isinstance(id_or_name, basestring) and hasattr(cls, 'name_field'): 842 return manager.get(**{cls.name_field : id_or_name}) 843 raise ValueError( 844 'Invalid positional argument: %s (%s)' % (id_or_name, 845 type(id_or_name))) 846 847 848 @classmethod 849 def smart_get_bulk(cls, id_or_name_list): 850 invalid_inputs = [] 851 result_objects = [] 852 for id_or_name in id_or_name_list: 853 try: 854 result_objects.append(cls.smart_get(id_or_name)) 855 except cls.DoesNotExist: 856 invalid_inputs.append(id_or_name) 857 if invalid_inputs: 858 raise cls.DoesNotExist('The following %ss do not exist: %s' 859 % (cls.__name__.lower(), 860 ', '.join(invalid_inputs))) 861 return result_objects 862 863 864 def get_object_dict(self, extra_fields=None): 865 """\ 866 Return a dictionary mapping fields to this object's values. @param 867 extra_fields: list of extra attribute names to include, in addition to 868 the fields defined on this object. 869 """ 870 fields = self.get_field_dict().keys() 871 if extra_fields: 872 fields += extra_fields 873 object_dict = dict((field_name, getattr(self, field_name)) 874 for field_name in fields) 875 self.clean_object_dicts([object_dict]) 876 self._postprocess_object_dict(object_dict) 877 return object_dict 878 879 880 def _postprocess_object_dict(self, object_dict): 881 """For subclasses to override.""" 882 pass 883 884 885 @classmethod 886 def get_valid_manager(cls): 887 return cls.objects 888 889 890 def _record_attributes(self, attributes): 891 """ 892 See on_attribute_changed. 893 """ 894 assert not isinstance(attributes, basestring) 895 self._recorded_attributes = dict((attribute, getattr(self, attribute)) 896 for attribute in attributes) 897 898 899 def _check_for_updated_attributes(self): 900 """ 901 See on_attribute_changed. 902 """ 903 for attribute, original_value in self._recorded_attributes.iteritems(): 904 new_value = getattr(self, attribute) 905 if original_value != new_value: 906 self.on_attribute_changed(attribute, original_value) 907 self._record_attributes(self._recorded_attributes.keys()) 908 909 910 def on_attribute_changed(self, attribute, old_value): 911 """ 912 Called whenever an attribute is updated. To be overridden. 913 914 To use this method, you must: 915 * call _record_attributes() from __init__() (after making the super 916 call) with a list of attributes for which you want to be notified upon 917 change. 918 * call _check_for_updated_attributes() from save(). 919 """ 920 pass 921 922 923 def serialize(self, include_dependencies=True): 924 """Serializes the object with dependencies. 925 926 The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies 927 this function will serialize with the object. 928 929 @param include_dependencies: Whether or not to follow relations to 930 objects this object depends on. 931 This parameter is used when uploading 932 jobs from a shard to the main, as the 933 main already has all the dependent 934 objects. 935 936 @returns: Dictionary representation of the object. 937 """ 938 serialized = {} 939 for field in self._meta.concrete_model._meta.local_fields: 940 if field.rel is None: 941 serialized[field.name] = field._get_val_from_obj(self) 942 elif field.name in self.SERIALIZATION_LINKS_TO_KEEP: 943 # attname will contain "_id" suffix for foreign keys, 944 # e.g. HostAttribute.host will be serialized as 'host_id'. 945 # Use it for easy deserialization. 946 serialized[field.attname] = field._get_val_from_obj(self) 947 948 if include_dependencies: 949 for link in self.SERIALIZATION_LINKS_TO_FOLLOW: 950 serialized[link] = self._serialize_relation(link) 951 952 return serialized 953 954 955 def _serialize_relation(self, link): 956 """Serializes dependent objects given the name of the relation. 957 958 @param link: Name of the relation to take objects from. 959 960 @returns For To-Many relationships a list of the serialized related 961 objects, for To-One relationships the serialized related object. 962 """ 963 try: 964 attr = getattr(self, link) 965 except AttributeError: 966 # One-To-One relationships that point to None may raise this 967 return None 968 969 if attr is None: 970 return None 971 if hasattr(attr, 'all'): 972 return [obj.serialize() for obj in attr.all()] 973 return attr.serialize() 974 975 976 @classmethod 977 def _split_local_from_foreign_values(cls, data): 978 """This splits local from foreign values in a serialized object. 979 980 @param data: The serialized object. 981 982 @returns A tuple of two lists, both containing tuples in the form 983 (link_name, link_value). The first list contains all links 984 for local fields, the second one contains those for foreign 985 fields/objects. 986 """ 987 links_to_local_values, links_to_related_values = [], [] 988 for link, value in data.iteritems(): 989 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW: 990 # It's a foreign key 991 links_to_related_values.append((link, value)) 992 else: 993 # It's a local attribute or a foreign key 994 # we don't want to follow. 995 links_to_local_values.append((link, value)) 996 return links_to_local_values, links_to_related_values 997 998 999 @classmethod 1000 def _filter_update_allowed_fields(cls, data): 1001 """Filters data and returns only files that updates are allowed on. 1002 1003 This is i.e. needed for syncing aborted bits from the main to shards. 1004 1005 Local links are only allowed to be updated, if they are in 1006 SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1007 Overwriting existing values is allowed in order to be able to sync i.e. 1008 the aborted bit from the main to a shard. 1009 1010 The allowlisting mechanism is in place to prevent overwriting local 1011 status: If all fields were overwritten, jobs would be completely be 1012 set back to their original (unstarted) state. 1013 1014 @param data: List with tuples of the form (link_name, link_value), as 1015 returned by _split_local_from_foreign_values. 1016 1017 @returns List of the same format as data, but only containing data for 1018 fields that updates are allowed on. 1019 """ 1020 return [pair for pair in data 1021 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE] 1022 1023 1024 @classmethod 1025 def delete_matching_record(cls, **filter_args): 1026 """Delete records matching the filter. 1027 1028 @param filter_args: Arguments for the django filter 1029 used to locate the record to delete. 1030 """ 1031 try: 1032 existing_record = cls.objects.get(**filter_args) 1033 except cls.DoesNotExist: 1034 return 1035 existing_record.delete() 1036 1037 1038 def _deserialize_local(self, data): 1039 """Set local attributes from a list of tuples. 1040 1041 @param data: List of tuples like returned by 1042 _split_local_from_foreign_values. 1043 """ 1044 if not data: 1045 return 1046 1047 for link, value in data: 1048 setattr(self, link, value) 1049 # Overwridden save() methods are prone to errors, so don't execute them. 1050 # This is because: 1051 # - the overwritten methods depend on ACL groups that don't yet exist 1052 # and don't handle errors 1053 # - the overwritten methods think this object already exists in the db 1054 # because the id is already set 1055 super(type(self), self).save() 1056 1057 1058 def _deserialize_relations(self, data): 1059 """Set foreign attributes from a list of tuples. 1060 1061 This deserialized the related objects using their own deserialize() 1062 function and then sets the relation. 1063 1064 @param data: List of tuples like returned by 1065 _split_local_from_foreign_values. 1066 """ 1067 for link, value in data: 1068 self._deserialize_relation(link, value) 1069 # See comment in _deserialize_local 1070 super(type(self), self).save() 1071 1072 1073 @classmethod 1074 def get_record(cls, data): 1075 """Retrieve a record with the data in the given input arg. 1076 1077 @param data: A dictionary containing the information to use in a query 1078 for data. If child models have different constraints of 1079 uniqueness they should override this model. 1080 1081 @return: An object with matching data. 1082 1083 @raises DoesNotExist: If a record with the given data doesn't exist. 1084 """ 1085 return cls.objects.get(id=data['id']) 1086 1087 1088 @classmethod 1089 def deserialize(cls, data): 1090 """Recursively deserializes and saves an object with it's dependencies. 1091 1092 This takes the result of the serialize method and creates objects 1093 in the database that are just like the original. 1094 1095 If an object of the same type with the same id already exists, it's 1096 local values will be left untouched, unless they are explicitly 1097 allowlisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1098 1099 Deserialize will always recursively propagate to all related objects 1100 present in data though. 1101 I.e. this is necessary to add users to an already existing acl-group. 1102 1103 @param data: Representation of an object and its dependencies, as 1104 returned by serialize. 1105 1106 @returns: The object represented by data if it didn't exist before, 1107 otherwise the object that existed before and has the same type 1108 and id as the one described by data. 1109 """ 1110 if data is None: 1111 return None 1112 1113 local, related = cls._split_local_from_foreign_values(data) 1114 try: 1115 instance = cls.get_record(data) 1116 local = cls._filter_update_allowed_fields(local) 1117 except cls.DoesNotExist: 1118 instance = cls() 1119 1120 instance._deserialize_local(local) 1121 instance._deserialize_relations(related) 1122 1123 return instance 1124 1125 1126 def sanity_check_update_from_shard(self, shard, updated_serialized, 1127 *args, **kwargs): 1128 """Check if an update sent from a shard is legitimate. 1129 1130 @raises error.UnallowedRecordsSentToMain if an update is not 1131 legitimate. 1132 """ 1133 raise NotImplementedError( 1134 'sanity_check_update_from_shard must be implemented by subclass %s ' 1135 'for type %s' % type(self)) 1136 1137 1138 @transaction.commit_on_success 1139 def update_from_serialized(self, serialized): 1140 """Updates local fields of an existing object from a serialized form. 1141 1142 This is different than the normal deserialize() in the way that it 1143 does update local values, which deserialize doesn't, but doesn't 1144 recursively propagate to related objects, which deserialize() does. 1145 1146 The use case of this function is to update job records on the main 1147 after the jobs have been executed on a shard, as the main is not 1148 interested in updates for users, labels, specialtasks, etc. 1149 1150 @param serialized: Representation of an object and its dependencies, as 1151 returned by serialize. 1152 1153 @raises ValueError: if serialized contains related objects, i.e. not 1154 only local fields. 1155 """ 1156 local, related = ( 1157 self._split_local_from_foreign_values(serialized)) 1158 if related: 1159 raise ValueError('Serialized must not contain foreign ' 1160 'objects: %s' % related) 1161 1162 self._deserialize_local(local) 1163 1164 1165 def custom_deserialize_relation(self, link, data): 1166 """Allows overriding the deserialization behaviour by subclasses.""" 1167 raise NotImplementedError( 1168 'custom_deserialize_relation must be implemented by subclass %s ' 1169 'for relation %s' % (type(self), link)) 1170 1171 1172 def _deserialize_relation(self, link, data): 1173 """Deserializes related objects and sets references on this object. 1174 1175 Relations that point to a list of objects are handled automatically. 1176 For many-to-one or one-to-one relations custom_deserialize_relation 1177 must be overridden by the subclass. 1178 1179 Related objects are deserialized using their deserialize() method. 1180 Thereby they and their dependencies are created if they don't exist 1181 and saved to the database. 1182 1183 @param link: Name of the relation. 1184 @param data: Serialized representation of the related object(s). 1185 This means a list of dictionaries for to-many relations, 1186 just a dictionary for to-one relations. 1187 """ 1188 field = getattr(self, link) 1189 1190 if field and hasattr(field, 'all'): 1191 self._deserialize_2m_relation(link, data, field.model) 1192 else: 1193 self.custom_deserialize_relation(link, data) 1194 1195 1196 def _deserialize_2m_relation(self, link, data, related_class): 1197 """Deserialize related objects for one to-many relationship. 1198 1199 @param link: Name of the relation. 1200 @param data: Serialized representation of the related objects. 1201 This is a list with of dictionaries. 1202 @param related_class: A class representing a django model, with which 1203 this class has a one-to-many relationship. 1204 """ 1205 relation_set = getattr(self, link) 1206 if related_class == self.get_attribute_model(): 1207 # When deserializing a model together with 1208 # its attributes, clear all the exising attributes to ensure 1209 # db consistency. Note 'update' won't be sufficient, as we also 1210 # want to remove any attributes that no longer exist in |data|. 1211 # 1212 # core_filters is a dictionary of filters, defines how 1213 # RelatedMangager would query for the 1-to-many relationship. E.g. 1214 # Host.objects.get( 1215 # id=20).hostattribute_set.core_filters = {host_id:20} 1216 # We use it to delete objects related to the current object. 1217 related_class.objects.filter(**relation_set.core_filters).delete() 1218 for serialized in data: 1219 relation_set.add(related_class.deserialize(serialized)) 1220 1221 1222 @classmethod 1223 def get_attribute_model(cls): 1224 """Return the attribute model. 1225 1226 Subclass with attribute-like model should override this to 1227 return the attribute model class. This method will be 1228 called by _deserialize_2m_relation to determine whether 1229 to clear the one-to-many relations first on deserialization of object. 1230 """ 1231 return None 1232 1233 1234class ModelWithInvalid(ModelExtensions): 1235 """ 1236 Overrides model methods save() and delete() to support invalidation in 1237 place of actual deletion. Subclasses must have a boolean "invalid" 1238 field. 1239 """ 1240 1241 def save(self, *args, **kwargs): 1242 first_time = (self.id is None) 1243 if first_time: 1244 # see if this object was previously added and invalidated 1245 my_name = getattr(self, self.name_field) 1246 filters = {self.name_field : my_name, 'invalid' : True} 1247 try: 1248 old_object = self.__class__.objects.get(**filters) 1249 self.resurrect_object(old_object) 1250 except self.DoesNotExist: 1251 # no existing object 1252 pass 1253 1254 super(ModelWithInvalid, self).save(*args, **kwargs) 1255 1256 1257 def resurrect_object(self, old_object): 1258 """ 1259 Called when self is about to be saved for the first time and is actually 1260 "undeleting" a previously deleted object. Can be overridden by 1261 subclasses to copy data as desired from the deleted entry (but this 1262 superclass implementation must normally be called). 1263 """ 1264 self.id = old_object.id 1265 1266 1267 def clean_object(self): 1268 """ 1269 This method is called when an object is marked invalid. 1270 Subclasses should override this to clean up relationships that 1271 should no longer exist if the object were deleted. 1272 """ 1273 pass 1274 1275 1276 def delete(self): 1277 self.invalid = self.invalid 1278 assert not self.invalid 1279 self.invalid = True 1280 self.save() 1281 self.clean_object() 1282 1283 1284 @classmethod 1285 def get_valid_manager(cls): 1286 return cls.valid_objects 1287 1288 1289 class Manipulator(object): 1290 """ 1291 Force default manipulators to look only at valid objects - 1292 otherwise they will match against invalid objects when checking 1293 uniqueness. 1294 """ 1295 @classmethod 1296 def _prepare(cls, model): 1297 super(ModelWithInvalid.Manipulator, cls)._prepare(model) 1298 cls.manager = model.valid_objects 1299 1300 1301class ModelWithAttributes(object): 1302 """ 1303 Mixin class for models that have an attribute model associated with them. 1304 The attribute model is assumed to have its value field named "value". 1305 """ 1306 1307 def _get_attribute_model_and_args(self, attribute): 1308 """ 1309 Subclasses should override this to return a tuple (attribute_model, 1310 keyword_args), where attribute_model is a model class and keyword_args 1311 is a dict of args to pass to attribute_model.objects.get() to get an 1312 instance of the given attribute on this object. 1313 """ 1314 raise NotImplementedError 1315 1316 1317 def _is_replaced_by_static_attribute(self, attribute): 1318 """ 1319 Subclasses could override this to indicate whether it has static 1320 attributes. 1321 """ 1322 return False 1323 1324 1325 def set_attribute(self, attribute, value): 1326 if self._is_replaced_by_static_attribute(attribute): 1327 raise error.UnmodifiableAttributeException( 1328 'Failed to set attribute "%s" for host "%s" since it ' 1329 'is static. Use go/chromeos-skylab-inventory-tools to ' 1330 'modify this attribute.' % (attribute, self.hostname)) 1331 1332 attribute_model, get_args = self._get_attribute_model_and_args( 1333 attribute) 1334 attribute_object, _ = attribute_model.objects.get_or_create(**get_args) 1335 attribute_object.value = value 1336 attribute_object.save() 1337 1338 1339 def delete_attribute(self, attribute): 1340 if self._is_replaced_by_static_attribute(attribute): 1341 raise error.UnmodifiableAttributeException( 1342 'Failed to delete attribute "%s" for host "%s" since it ' 1343 'is static. Use go/chromeos-skylab-inventory-tools to ' 1344 'modify this attribute.' % (attribute, self.hostname)) 1345 1346 attribute_model, get_args = self._get_attribute_model_and_args( 1347 attribute) 1348 try: 1349 attribute_model.objects.get(**get_args).delete() 1350 except attribute_model.DoesNotExist: 1351 pass 1352 1353 1354 def set_or_delete_attribute(self, attribute, value): 1355 if value is None: 1356 self.delete_attribute(attribute) 1357 else: 1358 self.set_attribute(attribute, value) 1359 1360 1361class ModelWithHashManager(dbmodels.Manager): 1362 """Manager for use with the ModelWithHash abstract model class""" 1363 1364 def create(self, **kwargs): 1365 raise Exception('ModelWithHash manager should use get_or_create() ' 1366 'instead of create()') 1367 1368 1369 def get_or_create(self, **kwargs): 1370 kwargs['the_hash'] = self.model._compute_hash(**kwargs) 1371 return super(ModelWithHashManager, self).get_or_create(**kwargs) 1372 1373 1374class ModelWithHash(dbmodels.Model): 1375 """Superclass with methods for dealing with a hash column""" 1376 1377 the_hash = dbmodels.CharField(max_length=40, unique=True) 1378 1379 objects = ModelWithHashManager() 1380 1381 class Meta: 1382 abstract = True 1383 1384 1385 @classmethod 1386 def _compute_hash(cls, **kwargs): 1387 raise NotImplementedError('Subclasses must override _compute_hash()') 1388 1389 1390 def save(self, force_insert=False, **kwargs): 1391 """Prevents saving the model in most cases 1392 1393 We want these models to be immutable, so the generic save() operation 1394 will not work. These models should be instantiated through their the 1395 model.objects.get_or_create() method instead. 1396 1397 The exception is that save(force_insert=True) will be allowed, since 1398 that creates a new row. However, the preferred way to make instances of 1399 these models is through the get_or_create() method. 1400 """ 1401 if not force_insert: 1402 # Allow a forced insert to happen; if it's a duplicate, the unique 1403 # constraint will catch it later anyways 1404 raise Exception('ModelWithHash is immutable') 1405 super(ModelWithHash, self).save(force_insert=force_insert, **kwargs) 1406