1""" 2Extensions to Django's model logic. 3""" 4 5import django.core.exceptions 6import django.db.models.sql.where 7import six 8from autotest_lib.client.common_lib import error 9from autotest_lib.frontend.afe import rdb_model_extensions 10from django.db import connection, connections 11from django.db import models as dbmodels 12from django.db import transaction 13from django.db.models.sql import query 14 15 16class ValidationError(django.core.exceptions.ValidationError): 17 """\ 18 Data validation error in adding or updating an object. The associated 19 value is a dictionary mapping field names to error strings. 20 """ 21 22def _quote_name(name): 23 """Shorthand for connection.ops.quote_name().""" 24 return connection.ops.quote_name(name) 25 26 27class LeasedHostManager(dbmodels.Manager): 28 """Query manager for unleased, unlocked hosts. 29 """ 30 def get_query_set(self): 31 return (super(LeasedHostManager, self).get_query_set().filter( 32 leased=0, locked=0)) 33 34 35class ExtendedManager(dbmodels.Manager): 36 """\ 37 Extended manager supporting subquery filtering. 38 """ 39 40 class CustomQuery(query.Query): 41 """A custom query""" 42 43 def __init__(self, *args, **kwargs): 44 super(ExtendedManager.CustomQuery, self).__init__(*args, **kwargs) 45 self._custom_joins = [] 46 47 48 def clone(self, klass=None, **kwargs): 49 """Clones the query and returns the clone.""" 50 obj = super(ExtendedManager.CustomQuery, self).clone(klass) 51 obj._custom_joins = list(self._custom_joins) 52 return obj 53 54 55 def combine(self, rhs, connector): 56 """Combines query with another query.""" 57 super(ExtendedManager.CustomQuery, self).combine(rhs, connector) 58 if hasattr(rhs, '_custom_joins'): 59 self._custom_joins.extend(rhs._custom_joins) 60 61 62 def add_custom_join(self, table, condition, join_type, 63 condition_values=(), alias=None): 64 """Adds a custom join to the query.""" 65 if alias is None: 66 alias = table 67 join_dict = dict(table=table, 68 condition=condition, 69 condition_values=condition_values, 70 join_type=join_type, 71 alias=alias) 72 self._custom_joins.append(join_dict) 73 74 75 @classmethod 76 def convert_query(self, query_set): 77 """ 78 Convert the query set's "query" attribute to a CustomQuery. 79 """ 80 # Make a copy of the query set 81 query_set = query_set.all() 82 query_set.query = query_set.query.clone( 83 klass=ExtendedManager.CustomQuery, 84 _custom_joins=[]) 85 return query_set 86 87 88 class _WhereClause(object): 89 """Object allowing us to inject arbitrary SQL into Django queries. 90 91 By using this instead of extra(where=...), we can still freely combine 92 queries with & and |. 93 """ 94 def __init__(self, clause, values=()): 95 self._clause = clause 96 self._values = values 97 98 99 def as_sql(self, qn=None, connection=None): 100 """Converts the clause to SQL and returns it.""" 101 return self._clause, self._values 102 103 104 def relabel_aliases(self, change_map): 105 """Does nothing.""" 106 return 107 108 109 def add_join(self, query_set, join_table, join_key, join_condition='', 110 join_condition_values=(), join_from_key=None, alias=None, 111 suffix='', exclude=False, force_left_join=False): 112 """Add a join to query_set. 113 114 Join looks like this: 115 (INNER|LEFT) JOIN <join_table> AS <alias> 116 ON (<this table>.<join_from_key> = <join_table>.<join_key> 117 and <join_condition>) 118 119 @param join_table table to join to 120 @param join_key field referencing back to this model to use for the join 121 @param join_condition extra condition for the ON clause of the join 122 @param join_condition_values values to substitute into join_condition 123 @param join_from_key column on this model to join from. 124 @param alias alias to use for for join 125 @param suffix suffix to add to join_table for the join alias, if no 126 alias is provided 127 @param exclude if true, exclude rows that match this join (will use a 128 LEFT OUTER JOIN and an appropriate WHERE condition) 129 @param force_left_join - if true, a LEFT OUTER JOIN will be used 130 instead of an INNER JOIN regardless of other options 131 """ 132 join_from_table = query_set.model._meta.db_table 133 if join_from_key is None: 134 join_from_key = self.model._meta.pk.name 135 if alias is None: 136 alias = join_table + suffix 137 full_join_key = _quote_name(alias) + '.' + _quote_name(join_key) 138 full_join_condition = '%s = %s.%s' % (full_join_key, 139 _quote_name(join_from_table), 140 _quote_name(join_from_key)) 141 if join_condition: 142 full_join_condition += ' AND (' + join_condition + ')' 143 if exclude or force_left_join: 144 join_type = query_set.query.LOUTER 145 else: 146 join_type = query_set.query.INNER 147 148 query_set = self.CustomQuery.convert_query(query_set) 149 query_set.query.add_custom_join(join_table, 150 full_join_condition, 151 join_type, 152 condition_values=join_condition_values, 153 alias=alias) 154 155 if exclude: 156 query_set = query_set.extra(where=[full_join_key + ' IS NULL']) 157 158 return query_set 159 160 161 def _info_for_many_to_one_join(self, field, join_to_query, alias): 162 """ 163 @param field: the ForeignKey field on the related model 164 @param join_to_query: the query over the related model that we're 165 joining to 166 @param alias: alias of joined table 167 """ 168 info = {} 169 rhs_table = join_to_query.model._meta.db_table 170 info['rhs_table'] = rhs_table 171 info['rhs_column'] = field.column 172 info['lhs_column'] = field.rel.get_related_field().column 173 rhs_where = join_to_query.query.where 174 rhs_where.relabel_aliases({rhs_table: alias}) 175 compiler = join_to_query.query.get_compiler(using=join_to_query.db) 176 initial_clause, values = compiler.as_sql() 177 # initial_clause is compiled from `join_to_query`, which is a SELECT 178 # query returns at most one record. For it to be used in WHERE clause, 179 # it must be converted to a boolean value using EXISTS. 180 all_clauses = ('EXISTS (%s)' % initial_clause,) 181 if hasattr(join_to_query.query, 'extra_where'): 182 all_clauses += join_to_query.query.extra_where 183 info['where_clause'] = ( 184 ' AND '.join('(%s)' % clause for clause in all_clauses)) 185 info['values'] = values 186 return info 187 188 189 def _info_for_many_to_many_join(self, m2m_field, join_to_query, alias, 190 m2m_is_on_this_model): 191 """ 192 @param m2m_field: a Django field representing the M2M relationship. 193 It uses a pivot table with the following structure: 194 this model table <---> M2M pivot table <---> joined model table 195 @param join_to_query: the query over the related model that we're 196 joining to. 197 @param alias: alias of joined table 198 """ 199 if m2m_is_on_this_model: 200 # referenced field on this model 201 lhs_id_field = self.model._meta.pk 202 # foreign key on the pivot table referencing lhs_id_field 203 m2m_lhs_column = m2m_field.m2m_column_name() 204 # foreign key on the pivot table referencing rhd_id_field 205 m2m_rhs_column = m2m_field.m2m_reverse_name() 206 # referenced field on related model 207 rhs_id_field = m2m_field.rel.get_related_field() 208 else: 209 lhs_id_field = m2m_field.rel.get_related_field() 210 m2m_lhs_column = m2m_field.m2m_reverse_name() 211 m2m_rhs_column = m2m_field.m2m_column_name() 212 rhs_id_field = join_to_query.model._meta.pk 213 214 info = {} 215 info['rhs_table'] = m2m_field.m2m_db_table() 216 info['rhs_column'] = m2m_lhs_column 217 info['lhs_column'] = lhs_id_field.column 218 219 # select the ID of related models relevant to this join. we can only do 220 # a single join, so we need to gather this information up front and 221 # include it in the join condition. 222 rhs_ids = join_to_query.values_list(rhs_id_field.attname, flat=True) 223 assert len(rhs_ids) == 1, ('Many-to-many custom field joins can only ' 224 'match a single related object.') 225 rhs_id = rhs_ids[0] 226 227 info['where_clause'] = '%s.%s = %s' % (_quote_name(alias), 228 _quote_name(m2m_rhs_column), 229 rhs_id) 230 info['values'] = () 231 return info 232 233 234 def join_custom_field(self, query_set, join_to_query, alias, 235 left_join=True): 236 """Join to a related model to create a custom field in the given query. 237 238 This method is used to construct a custom field on the given query based 239 on a many-valued relationsip. join_to_query should be a simple query 240 (no joins) on the related model which returns at most one related row 241 per instance of this model. 242 243 For many-to-one relationships, the joined table contains the matching 244 row from the related model it one is related, NULL otherwise. 245 246 For many-to-many relationships, the joined table contains the matching 247 row if it's related, NULL otherwise. 248 """ 249 relationship_type, field = self.determine_relationship( 250 join_to_query.model) 251 252 if relationship_type == self.MANY_TO_ONE: 253 info = self._info_for_many_to_one_join(field, join_to_query, alias) 254 elif relationship_type == self.M2M_ON_RELATED_MODEL: 255 info = self._info_for_many_to_many_join( 256 m2m_field=field, join_to_query=join_to_query, alias=alias, 257 m2m_is_on_this_model=False) 258 elif relationship_type ==self.M2M_ON_THIS_MODEL: 259 info = self._info_for_many_to_many_join( 260 m2m_field=field, join_to_query=join_to_query, alias=alias, 261 m2m_is_on_this_model=True) 262 263 return self.add_join(query_set, info['rhs_table'], info['rhs_column'], 264 join_from_key=info['lhs_column'], 265 join_condition=info['where_clause'], 266 join_condition_values=info['values'], 267 alias=alias, 268 force_left_join=left_join) 269 270 271 def add_where(self, query_set, where, values=()): 272 """Adds a where clause to the query_set.""" 273 query_set = query_set.all() 274 query_set.query.where.add(self._WhereClause(where, values), 275 django.db.models.sql.where.AND) 276 return query_set 277 278 279 def _get_quoted_field(self, table, field): 280 return _quote_name(table) + '.' + _quote_name(field) 281 282 283 def get_key_on_this_table(self, key_field=None): 284 if key_field is None: 285 # default to primary key 286 key_field = self.model._meta.pk.column 287 return self._get_quoted_field(self.model._meta.db_table, key_field) 288 289 290 def escape_user_sql(self, sql): 291 """Escapes % in sql.""" 292 return sql.replace('%', '%%') 293 294 295 def _custom_select_query(self, query_set, selects): 296 """Execute a custom select query. 297 298 @param query_set: query set as returned by query_objects. 299 @param selects: Tables/Columns to select, e.g. tko_test_labels_list.id. 300 301 @returns: Result of the query as returned by cursor.fetchall(). 302 """ 303 compiler = query_set.query.get_compiler(using=query_set.db) 304 sql, params = compiler.as_sql() 305 from_ = sql[sql.find(' FROM'):] 306 307 if query_set.query.distinct: 308 distinct = 'DISTINCT ' 309 else: 310 distinct = '' 311 312 sql_query = ('SELECT ' + distinct + ','.join(selects) + from_) 313 # Chose the connection that's responsible for this type of object 314 cursor = connections[query_set.db].cursor() 315 cursor.execute(sql_query, params) 316 return cursor.fetchall() 317 318 319 def _is_relation_to(self, field, model_class): 320 return field.rel and field.rel.to is model_class 321 322 323 MANY_TO_ONE = object() 324 M2M_ON_RELATED_MODEL = object() 325 M2M_ON_THIS_MODEL = object() 326 327 def determine_relationship(self, related_model): 328 """ 329 Determine the relationship between this model and related_model. 330 331 related_model must have some sort of many-valued relationship to this 332 manager's model. 333 @returns (relationship_type, field), where relationship_type is one of 334 MANY_TO_ONE, M2M_ON_RELATED_MODEL, M2M_ON_THIS_MODEL, and field 335 is the Django field object for the relationship. 336 """ 337 # look for a foreign key field on related_model relating to this model 338 for field in related_model._meta.fields: 339 if self._is_relation_to(field, self.model): 340 return self.MANY_TO_ONE, field 341 342 # look for an M2M field on related_model relating to this model 343 for field in related_model._meta.many_to_many: 344 if self._is_relation_to(field, self.model): 345 return self.M2M_ON_RELATED_MODEL, field 346 347 # maybe this model has the many-to-many field 348 for field in self.model._meta.many_to_many: 349 if self._is_relation_to(field, related_model): 350 return self.M2M_ON_THIS_MODEL, field 351 352 raise ValueError('%s has no relation to %s' % 353 (related_model, self.model)) 354 355 356 def _get_pivot_iterator(self, base_objects_by_id, related_model): 357 """ 358 Determine the relationship between this model and related_model, and 359 return a pivot iterator. 360 @param base_objects_by_id: dict of instances of this model indexed by 361 their IDs 362 @returns a pivot iterator, which yields a tuple (base_object, 363 related_object) for each relationship between a base object and a 364 related object. all base_object instances come from base_objects_by_id. 365 Note -- this depends on Django model internals. 366 """ 367 relationship_type, field = self.determine_relationship(related_model) 368 if relationship_type == self.MANY_TO_ONE: 369 return self._many_to_one_pivot(base_objects_by_id, 370 related_model, field) 371 elif relationship_type == self.M2M_ON_RELATED_MODEL: 372 return self._many_to_many_pivot( 373 base_objects_by_id, related_model, field.m2m_db_table(), 374 field.m2m_reverse_name(), field.m2m_column_name()) 375 else: 376 assert relationship_type == self.M2M_ON_THIS_MODEL 377 return self._many_to_many_pivot( 378 base_objects_by_id, related_model, field.m2m_db_table(), 379 field.m2m_column_name(), field.m2m_reverse_name()) 380 381 382 def _many_to_one_pivot(self, base_objects_by_id, related_model, 383 foreign_key_field): 384 """ 385 @returns a pivot iterator - see _get_pivot_iterator() 386 """ 387 filter_data = {foreign_key_field.name + '__pk__in': 388 base_objects_by_id.keys()} 389 for related_object in related_model.objects.filter(**filter_data): 390 # lookup base object in the dict, rather than grabbing it from the 391 # related object. we need to return instances from the dict, not 392 # fresh instances of the same models (and grabbing model instances 393 # from the related models incurs a DB query each time). 394 base_object_id = getattr(related_object, foreign_key_field.attname) 395 base_object = base_objects_by_id[base_object_id] 396 yield base_object, related_object 397 398 399 def _query_pivot_table(self, base_objects_by_id, pivot_table, 400 pivot_from_field, pivot_to_field, related_model): 401 """ 402 @param id_list list of IDs of self.model objects to include 403 @param pivot_table the name of the pivot table 404 @param pivot_from_field a field name on pivot_table referencing 405 self.model 406 @param pivot_to_field a field name on pivot_table referencing the 407 related model. 408 @param related_model the related model 409 410 @returns pivot list of IDs (base_id, related_id) 411 """ 412 query = """ 413 SELECT %(from_field)s, %(to_field)s 414 FROM %(table)s 415 WHERE %(from_field)s IN (%(id_list)s) 416 """ % dict(from_field=pivot_from_field, 417 to_field=pivot_to_field, 418 table=pivot_table, 419 id_list=','.join( 420 str(id_) 421 for id_ in six.iterkeys(base_objects_by_id))) 422 423 # Chose the connection that's responsible for this type of object 424 # The databases for related_model and the current model will always 425 # be the same, related_model is just easier to obtain here because 426 # self is only a ExtendedManager, not the object. 427 cursor = connections[related_model.objects.db].cursor() 428 cursor.execute(query) 429 return cursor.fetchall() 430 431 432 def _many_to_many_pivot(self, base_objects_by_id, related_model, 433 pivot_table, pivot_from_field, pivot_to_field): 434 """ 435 @param pivot_table: see _query_pivot_table 436 @param pivot_from_field: see _query_pivot_table 437 @param pivot_to_field: see _query_pivot_table 438 @returns a pivot iterator - see _get_pivot_iterator() 439 """ 440 id_pivot = self._query_pivot_table(base_objects_by_id, pivot_table, 441 pivot_from_field, pivot_to_field, 442 related_model) 443 444 all_related_ids = list(set(related_id for base_id, related_id 445 in id_pivot)) 446 related_objects_by_id = related_model.objects.in_bulk(all_related_ids) 447 448 for base_id, related_id in id_pivot: 449 yield base_objects_by_id[base_id], related_objects_by_id[related_id] 450 451 452 def populate_relationships(self, base_objects, related_model, 453 related_list_name): 454 """ 455 For each instance of this model in base_objects, add a field named 456 related_list_name listing all the related objects of type related_model. 457 related_model must be in a many-to-one or many-to-many relationship with 458 this model. 459 @param base_objects - list of instances of this model 460 @param related_model - model class related to this model 461 @param related_list_name - attribute name in which to store the related 462 object list. 463 """ 464 if not base_objects: 465 # if we don't bail early, we'll get a SQL error later 466 return 467 468 # The default maximum value of a host parameter number in SQLite is 999. 469 # Exceed this will get a DatabaseError later. 470 batch_size = 900 471 for i in range(0, len(base_objects), batch_size): 472 base_objects_batch = base_objects[i:i + batch_size] 473 base_objects_by_id = dict((base_object._get_pk_val(), base_object) 474 for base_object in base_objects_batch) 475 pivot_iterator = self._get_pivot_iterator(base_objects_by_id, 476 related_model) 477 478 for base_object in base_objects_batch: 479 setattr(base_object, related_list_name, []) 480 481 for base_object, related_object in pivot_iterator: 482 getattr(base_object, related_list_name).append(related_object) 483 484 485class ModelWithInvalidQuerySet(dbmodels.query.QuerySet): 486 """ 487 QuerySet that handles delete() properly for models with an "invalid" bit 488 """ 489 def delete(self): 490 """Deletes the QuerySet.""" 491 for model in self: 492 model.delete() 493 494 495class ModelWithInvalidManager(ExtendedManager): 496 """ 497 Manager for objects with an "invalid" bit 498 """ 499 def get_query_set(self): 500 return ModelWithInvalidQuerySet(self.model) 501 502 503class ValidObjectsManager(ModelWithInvalidManager): 504 """ 505 Manager returning only objects with invalid=False. 506 """ 507 def get_query_set(self): 508 queryset = super(ValidObjectsManager, self).get_query_set() 509 return queryset.filter(invalid=False) 510 511 512class ModelExtensions(rdb_model_extensions.ModelValidators): 513 """\ 514 Mixin with convenience functions for models, built on top of 515 the model validators in rdb_model_extensions. 516 """ 517 # TODO: at least some of these functions really belong in a custom 518 # Manager class 519 520 521 SERIALIZATION_LINKS_TO_FOLLOW = set() 522 """ 523 To be able to send jobs and hosts to shards, it's necessary to find their 524 dependencies. 525 The most generic approach for this would be to traverse all relationships 526 to other objects recursively. This would list all objects that are related 527 in any way. 528 But this approach finds too many objects: If a host should be transferred, 529 all it's relationships would be traversed. This would find an acl group. 530 If then the acl group's relationships are traversed, the relationship 531 would be followed backwards and many other hosts would be found. 532 533 This mapping tells that algorithm which relations to follow explicitly. 534 """ 535 536 537 SERIALIZATION_LINKS_TO_KEEP = set() 538 """This set stores foreign keys which we don't want to follow, but 539 still want to include in the serialized dictionary. For 540 example, we follow the relationship `Host.hostattribute_set`, 541 but we do not want to follow `HostAttributes.host_id` back to 542 to Host, which would otherwise lead to a circle. However, we still 543 like to serialize HostAttribute.`host_id`.""" 544 545 SERIALIZATION_LOCAL_LINKS_TO_UPDATE = set() 546 """ 547 On deserializion, if the object to persist already exists, local fields 548 will only be updated, if their name is in this set. 549 """ 550 551 552 @classmethod 553 def convert_human_readable_values(cls, data, to_human_readable=False): 554 """\ 555 Performs conversions on user-supplied field data, to make it 556 easier for users to pass human-readable data. 557 558 For all fields that have choice sets, convert their values 559 from human-readable strings to enum values, if necessary. This 560 allows users to pass strings instead of the corresponding 561 integer values. 562 563 For all foreign key fields, call smart_get with the supplied 564 data. This allows the user to pass either an ID value or 565 the name of the object as a string. 566 567 If to_human_readable=True, perform the inverse - i.e. convert 568 numeric values to human readable values. 569 570 This method modifies data in-place. 571 """ 572 field_dict = cls.get_field_dict() 573 for field_name in data: 574 if field_name not in field_dict or data[field_name] is None: 575 continue 576 field_obj = field_dict[field_name] 577 # convert enum values 578 if field_obj.choices: 579 for choice_data in field_obj.choices: 580 # choice_data is (value, name) 581 if to_human_readable: 582 from_val, to_val = choice_data 583 else: 584 to_val, from_val = choice_data 585 if from_val == data[field_name]: 586 data[field_name] = to_val 587 break 588 # convert foreign key values 589 elif field_obj.rel: 590 dest_obj = field_obj.rel.to.smart_get(data[field_name], 591 valid_only=False) 592 if to_human_readable: 593 # parameterized_jobs do not have a name_field 594 if (field_name != 'parameterized_job' and 595 dest_obj.name_field is not None): 596 data[field_name] = getattr(dest_obj, 597 dest_obj.name_field) 598 else: 599 data[field_name] = dest_obj 600 601 602 603 604 def _validate_unique(self): 605 """\ 606 Validate that unique fields are unique. Django manipulators do 607 this too, but they're a huge pain to use manually. Trust me. 608 """ 609 errors = {} 610 cls = type(self) 611 field_dict = self.get_field_dict() 612 manager = cls.get_valid_manager() 613 for field_name, field_obj in six.iteritems(field_dict): 614 if not field_obj.unique: 615 continue 616 617 value = getattr(self, field_name) 618 if value is None and field_obj.auto_created: 619 # don't bother checking autoincrement fields about to be 620 # generated 621 continue 622 623 existing_objs = manager.filter(**{field_name : value}) 624 num_existing = existing_objs.count() 625 626 if num_existing == 0: 627 continue 628 if num_existing == 1 and existing_objs[0].id == self.id: 629 continue 630 errors[field_name] = ( 631 'This value must be unique (%s)' % (value)) 632 return errors 633 634 635 def _validate(self): 636 """ 637 First coerces all fields on this instance to their proper Python types. 638 Then runs validation on every field. Returns a dictionary of 639 field_name -> error_list. 640 641 Based on validate() from django.db.models.Model in Django 0.96, which 642 was removed in Django 1.0. It should reappear in a later version. See: 643 http://code.djangoproject.com/ticket/6845 644 """ 645 error_dict = {} 646 for f in self._meta.fields: 647 try: 648 python_value = f.to_python( 649 getattr(self, f.attname, f.get_default())) 650 except django.core.exceptions.ValidationError as e: 651 error_dict[f.name] = str(e) 652 continue 653 654 if not f.blank and not python_value: 655 error_dict[f.name] = 'This field is required.' 656 continue 657 658 setattr(self, f.attname, python_value) 659 660 return error_dict 661 662 663 def do_validate(self): 664 """Validate fields.""" 665 errors = self._validate() 666 unique_errors = self._validate_unique() 667 for field_name, error in six.iteritems(unique_errors): 668 errors.setdefault(field_name, error) 669 if errors: 670 raise ValidationError(errors) 671 672 673 # actually (externally) useful methods follow 674 675 @classmethod 676 def add_object(cls, data={}, **kwargs): 677 """\ 678 Returns a new object created with the given data (a dictionary 679 mapping field names to values). Merges any extra keyword args 680 into data. 681 """ 682 data = dict(data) 683 data.update(kwargs) 684 data = cls.prepare_data_args(data) 685 cls.convert_human_readable_values(data) 686 data = cls.provide_default_values(data) 687 688 obj = cls(**data) 689 obj.do_validate() 690 obj.save() 691 return obj 692 693 694 def update_object(self, data={}, **kwargs): 695 """\ 696 Updates the object with the given data (a dictionary mapping 697 field names to values). Merges any extra keyword args into 698 data. 699 """ 700 data = dict(data) 701 data.update(kwargs) 702 data = self.prepare_data_args(data) 703 self.convert_human_readable_values(data) 704 for field_name, value in six.iteritems(data): 705 setattr(self, field_name, value) 706 self.do_validate() 707 self.save() 708 709 710 # see query_objects() 711 _SPECIAL_FILTER_KEYS = ('query_start', 'query_limit', 'sort_by', 712 'extra_args', 'extra_where', 'no_distinct') 713 714 715 @classmethod 716 def _extract_special_params(cls, filter_data): 717 """ 718 @returns a tuple of dicts (special_params, regular_filters), where 719 special_params contains the parameters we handle specially and 720 regular_filters is the remaining data to be handled by Django. 721 """ 722 regular_filters = dict(filter_data) 723 special_params = {} 724 for key in cls._SPECIAL_FILTER_KEYS: 725 if key in regular_filters: 726 special_params[key] = regular_filters.pop(key) 727 return special_params, regular_filters 728 729 730 @classmethod 731 def apply_presentation(cls, query, filter_data): 732 """ 733 Apply presentation parameters -- sorting and paging -- to the given 734 query. 735 @returns new query with presentation applied 736 """ 737 special_params, _ = cls._extract_special_params(filter_data) 738 sort_by = special_params.get('sort_by', None) 739 if sort_by: 740 assert isinstance(sort_by, list) or isinstance(sort_by, tuple) 741 query = query.extra(order_by=sort_by) 742 743 query_start = special_params.get('query_start', None) 744 query_limit = special_params.get('query_limit', None) 745 if query_start is not None: 746 if query_limit is None: 747 raise ValueError('Cannot pass query_start without query_limit') 748 # query_limit is passed as a page size 749 query_limit += query_start 750 return query[query_start:query_limit] 751 752 753 @classmethod 754 def query_objects(cls, filter_data, valid_only=True, initial_query=None, 755 apply_presentation=True): 756 """\ 757 Returns a QuerySet object for querying the given model_class 758 with the given filter_data. Optional special arguments in 759 filter_data include: 760 -query_start: index of first return to return 761 -query_limit: maximum number of results to return 762 -sort_by: list of fields to sort on. prefixing a '-' onto a 763 field name changes the sort to descending order. 764 -extra_args: keyword args to pass to query.extra() (see Django 765 DB layer documentation) 766 -extra_where: extra WHERE clause to append 767 -no_distinct: if True, a DISTINCT will not be added to the SELECT 768 """ 769 special_params, regular_filters = cls._extract_special_params( 770 filter_data) 771 772 if initial_query is None: 773 if valid_only: 774 initial_query = cls.get_valid_manager() 775 else: 776 initial_query = cls.objects 777 778 query = initial_query.filter(**regular_filters) 779 780 use_distinct = not special_params.get('no_distinct', False) 781 if use_distinct: 782 query = query.distinct() 783 784 extra_args = special_params.get('extra_args', {}) 785 extra_where = special_params.get('extra_where', None) 786 if extra_where: 787 # escape %'s 788 extra_where = cls.objects.escape_user_sql(extra_where) 789 extra_args.setdefault('where', []).append(extra_where) 790 if extra_args: 791 query = query.extra(**extra_args) 792 # TODO: Use readonly connection for these queries. 793 # This has been disabled, because it's not used anyway, as the 794 # configured readonly user is the same as the real user anyway. 795 796 if apply_presentation: 797 query = cls.apply_presentation(query, filter_data) 798 799 return query 800 801 802 @classmethod 803 def query_count(cls, filter_data, initial_query=None): 804 """\ 805 Like query_objects, but retreive only the count of results. 806 """ 807 filter_data.pop('query_start', None) 808 filter_data.pop('query_limit', None) 809 query = cls.query_objects(filter_data, initial_query=initial_query) 810 return query.count() 811 812 813 @classmethod 814 def clean_object_dicts(cls, field_dicts): 815 """\ 816 Take a list of dicts corresponding to object (as returned by 817 query.values()) and clean the data to be more suitable for 818 returning to the user. 819 """ 820 for field_dict in field_dicts: 821 cls.clean_foreign_keys(field_dict) 822 cls._convert_booleans(field_dict) 823 cls.convert_human_readable_values(field_dict, 824 to_human_readable=True) 825 826 827 @classmethod 828 def list_objects(cls, filter_data, initial_query=None): 829 """\ 830 Like query_objects, but return a list of dictionaries. 831 """ 832 query = cls.query_objects(filter_data, initial_query=initial_query) 833 extra_fields = query.query.extra_select.keys() 834 field_dicts = [model_object.get_object_dict(extra_fields=extra_fields) 835 for model_object in query] 836 return field_dicts 837 838 839 @classmethod 840 def smart_get(cls, id_or_name, valid_only=True): 841 """\ 842 smart_get(integer) -> get object by ID 843 smart_get(string) -> get object by name_field 844 """ 845 if valid_only: 846 manager = cls.get_valid_manager() 847 else: 848 manager = cls.objects 849 850 if isinstance(id_or_name, six.integer_types): 851 return manager.get(pk=id_or_name) 852 if isinstance(id_or_name, six.string_types) and hasattr( 853 cls, 'name_field'): 854 return manager.get(**{cls.name_field : id_or_name}) 855 raise ValueError( 856 'Invalid positional argument: %s (%s)' % (id_or_name, 857 type(id_or_name))) 858 859 860 @classmethod 861 def smart_get_bulk(cls, id_or_name_list): 862 """Like smart_get, but for a list of ids or names""" 863 invalid_inputs = [] 864 result_objects = [] 865 for id_or_name in id_or_name_list: 866 try: 867 result_objects.append(cls.smart_get(id_or_name)) 868 except cls.DoesNotExist: 869 invalid_inputs.append(id_or_name) 870 if invalid_inputs: 871 raise cls.DoesNotExist('The following %ss do not exist: %s' 872 % (cls.__name__.lower(), 873 ', '.join(invalid_inputs))) 874 return result_objects 875 876 877 def get_object_dict(self, extra_fields=None): 878 """\ 879 Return a dictionary mapping fields to this object's values. @param 880 extra_fields: list of extra attribute names to include, in addition to 881 the fields defined on this object. 882 """ 883 fields = self.get_field_dict().keys() 884 if extra_fields: 885 fields += extra_fields 886 object_dict = dict((field_name, getattr(self, field_name)) 887 for field_name in fields) 888 self.clean_object_dicts([object_dict]) 889 self._postprocess_object_dict(object_dict) 890 return object_dict 891 892 893 def _postprocess_object_dict(self, object_dict): 894 """For subclasses to override.""" 895 pass 896 897 898 @classmethod 899 def get_valid_manager(cls): 900 return cls.objects 901 902 903 def _record_attributes(self, attributes): 904 """ 905 See on_attribute_changed. 906 """ 907 assert not isinstance(attributes, six.string_types) 908 self._recorded_attributes = dict((attribute, getattr(self, attribute)) 909 for attribute in attributes) 910 911 912 def _check_for_updated_attributes(self): 913 """ 914 See on_attribute_changed. 915 """ 916 for attribute, original_value in six.iteritems( 917 self._recorded_attributes): 918 new_value = getattr(self, attribute) 919 if original_value != new_value: 920 self.on_attribute_changed(attribute, original_value) 921 self._record_attributes(self._recorded_attributes.keys()) 922 923 924 def on_attribute_changed(self, attribute, old_value): 925 """ 926 Called whenever an attribute is updated. To be overridden. 927 928 To use this method, you must: 929 * call _record_attributes() from __init__() (after making the super 930 call) with a list of attributes for which you want to be notified upon 931 change. 932 * call _check_for_updated_attributes() from save(). 933 """ 934 pass 935 936 937 def serialize(self, include_dependencies=True): 938 """Serializes the object with dependencies. 939 940 The variable SERIALIZATION_LINKS_TO_FOLLOW defines which dependencies 941 this function will serialize with the object. 942 943 @param include_dependencies: Whether or not to follow relations to 944 objects this object depends on. 945 This parameter is used when uploading 946 jobs from a shard to the main, as the 947 main already has all the dependent 948 objects. 949 950 @returns: Dictionary representation of the object. 951 """ 952 serialized = {} 953 for field in self._meta.concrete_model._meta.local_fields: 954 if field.rel is None: 955 serialized[field.name] = field._get_val_from_obj(self) 956 elif field.name in self.SERIALIZATION_LINKS_TO_KEEP: 957 # attname will contain "_id" suffix for foreign keys, 958 # e.g. HostAttribute.host will be serialized as 'host_id'. 959 # Use it for easy deserialization. 960 serialized[field.attname] = field._get_val_from_obj(self) 961 962 if include_dependencies: 963 for link in self.SERIALIZATION_LINKS_TO_FOLLOW: 964 serialized[link] = self._serialize_relation(link) 965 966 return serialized 967 968 969 def _serialize_relation(self, link): 970 """Serializes dependent objects given the name of the relation. 971 972 @param link: Name of the relation to take objects from. 973 974 @returns For To-Many relationships a list of the serialized related 975 objects, for To-One relationships the serialized related object. 976 """ 977 try: 978 attr = getattr(self, link) 979 except AttributeError: 980 # One-To-One relationships that point to None may raise this 981 return None 982 983 if attr is None: 984 return None 985 if hasattr(attr, 'all'): 986 return [obj.serialize() for obj in attr.all()] 987 return attr.serialize() 988 989 990 @classmethod 991 def _split_local_from_foreign_values(cls, data): 992 """This splits local from foreign values in a serialized object. 993 994 @param data: The serialized object. 995 996 @returns A tuple of two lists, both containing tuples in the form 997 (link_name, link_value). The first list contains all links 998 for local fields, the second one contains those for foreign 999 fields/objects. 1000 """ 1001 links_to_local_values, links_to_related_values = [], [] 1002 for link, value in six.iteritems(data): 1003 if link in cls.SERIALIZATION_LINKS_TO_FOLLOW: 1004 # It's a foreign key 1005 links_to_related_values.append((link, value)) 1006 else: 1007 # It's a local attribute or a foreign key 1008 # we don't want to follow. 1009 links_to_local_values.append((link, value)) 1010 return links_to_local_values, links_to_related_values 1011 1012 1013 @classmethod 1014 def _filter_update_allowed_fields(cls, data): 1015 """Filters data and returns only files that updates are allowed on. 1016 1017 This is i.e. needed for syncing aborted bits from the main to shards. 1018 1019 Local links are only allowed to be updated, if they are in 1020 SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1021 Overwriting existing values is allowed in order to be able to sync i.e. 1022 the aborted bit from the main to a shard. 1023 1024 The allowlisting mechanism is in place to prevent overwriting local 1025 status: If all fields were overwritten, jobs would be completely be 1026 set back to their original (unstarted) state. 1027 1028 @param data: List with tuples of the form (link_name, link_value), as 1029 returned by _split_local_from_foreign_values. 1030 1031 @returns List of the same format as data, but only containing data for 1032 fields that updates are allowed on. 1033 """ 1034 return [pair for pair in data 1035 if pair[0] in cls.SERIALIZATION_LOCAL_LINKS_TO_UPDATE] 1036 1037 1038 @classmethod 1039 def delete_matching_record(cls, **filter_args): 1040 """Delete records matching the filter. 1041 1042 @param filter_args: Arguments for the django filter 1043 used to locate the record to delete. 1044 """ 1045 try: 1046 existing_record = cls.objects.get(**filter_args) 1047 except cls.DoesNotExist: 1048 return 1049 existing_record.delete() 1050 1051 1052 def _deserialize_local(self, data): 1053 """Set local attributes from a list of tuples. 1054 1055 @param data: List of tuples like returned by 1056 _split_local_from_foreign_values. 1057 """ 1058 if not data: 1059 return 1060 1061 for link, value in data: 1062 setattr(self, link, value) 1063 # Overwridden save() methods are prone to errors, so don't execute them. 1064 # This is because: 1065 # - the overwritten methods depend on ACL groups that don't yet exist 1066 # and don't handle errors 1067 # - the overwritten methods think this object already exists in the db 1068 # because the id is already set 1069 super(type(self), self).save() 1070 1071 1072 def _deserialize_relations(self, data): 1073 """Set foreign attributes from a list of tuples. 1074 1075 This deserialized the related objects using their own deserialize() 1076 function and then sets the relation. 1077 1078 @param data: List of tuples like returned by 1079 _split_local_from_foreign_values. 1080 """ 1081 for link, value in data: 1082 self._deserialize_relation(link, value) 1083 # See comment in _deserialize_local 1084 super(type(self), self).save() 1085 1086 1087 @classmethod 1088 def get_record(cls, data): 1089 """Retrieve a record with the data in the given input arg. 1090 1091 @param data: A dictionary containing the information to use in a query 1092 for data. If child models have different constraints of 1093 uniqueness they should override this model. 1094 1095 @return: An object with matching data. 1096 1097 @raises DoesNotExist: If a record with the given data doesn't exist. 1098 """ 1099 return cls.objects.get(id=data['id']) 1100 1101 1102 @classmethod 1103 def deserialize(cls, data): 1104 """Recursively deserializes and saves an object with it's dependencies. 1105 1106 This takes the result of the serialize method and creates objects 1107 in the database that are just like the original. 1108 1109 If an object of the same type with the same id already exists, it's 1110 local values will be left untouched, unless they are explicitly 1111 allowlisted in SERIALIZATION_LOCAL_LINKS_TO_UPDATE. 1112 1113 Deserialize will always recursively propagate to all related objects 1114 present in data though. 1115 I.e. this is necessary to add users to an already existing acl-group. 1116 1117 @param data: Representation of an object and its dependencies, as 1118 returned by serialize. 1119 1120 @returns: The object represented by data if it didn't exist before, 1121 otherwise the object that existed before and has the same type 1122 and id as the one described by data. 1123 """ 1124 if data is None: 1125 return None 1126 1127 local, related = cls._split_local_from_foreign_values(data) 1128 try: 1129 instance = cls.get_record(data) 1130 local = cls._filter_update_allowed_fields(local) 1131 except cls.DoesNotExist: 1132 instance = cls() 1133 1134 instance._deserialize_local(local) 1135 instance._deserialize_relations(related) 1136 1137 return instance 1138 1139 1140 def _check_update_from_shard(self, shard, updated_serialized, 1141 *args, **kwargs): 1142 """Check if an update sent from a shard is legitimate. 1143 1144 @raises error.UnallowedRecordsSentToMain if an update is not 1145 legitimate. 1146 """ 1147 raise NotImplementedError( 1148 '_check_update_from_shard must be implemented by subclass %s ' 1149 'for type %s' % type(self)) 1150 1151 1152 @transaction.commit_on_success 1153 def update_from_serialized(self, serialized): 1154 """Updates local fields of an existing object from a serialized form. 1155 1156 This is different than the normal deserialize() in the way that it 1157 does update local values, which deserialize doesn't, but doesn't 1158 recursively propagate to related objects, which deserialize() does. 1159 1160 The use case of this function is to update job records on the main 1161 after the jobs have been executed on a shard, as the main is not 1162 interested in updates for users, labels, specialtasks, etc. 1163 1164 @param serialized: Representation of an object and its dependencies, as 1165 returned by serialize. 1166 1167 @raises ValueError: if serialized contains related objects, i.e. not 1168 only local fields. 1169 """ 1170 local, related = ( 1171 self._split_local_from_foreign_values(serialized)) 1172 if related: 1173 raise ValueError('Serialized must not contain foreign ' 1174 'objects: %s' % related) 1175 1176 self._deserialize_local(local) 1177 1178 1179 def custom_deserialize_relation(self, link, data): 1180 """Allows overriding the deserialization behaviour by subclasses.""" 1181 raise NotImplementedError( 1182 'custom_deserialize_relation must be implemented by subclass %s ' 1183 'for relation %s' % (type(self), link)) 1184 1185 1186 def _deserialize_relation(self, link, data): 1187 """Deserializes related objects and sets references on this object. 1188 1189 Relations that point to a list of objects are handled automatically. 1190 For many-to-one or one-to-one relations custom_deserialize_relation 1191 must be overridden by the subclass. 1192 1193 Related objects are deserialized using their deserialize() method. 1194 Thereby they and their dependencies are created if they don't exist 1195 and saved to the database. 1196 1197 @param link: Name of the relation. 1198 @param data: Serialized representation of the related object(s). 1199 This means a list of dictionaries for to-many relations, 1200 just a dictionary for to-one relations. 1201 """ 1202 field = getattr(self, link) 1203 1204 if field and hasattr(field, 'all'): 1205 self._deserialize_2m_relation(link, data, field.model) 1206 else: 1207 self.custom_deserialize_relation(link, data) 1208 1209 1210 def _deserialize_2m_relation(self, link, data, related_class): 1211 """Deserialize related objects for one to-many relationship. 1212 1213 @param link: Name of the relation. 1214 @param data: Serialized representation of the related objects. 1215 This is a list with of dictionaries. 1216 @param related_class: A class representing a django model, with which 1217 this class has a one-to-many relationship. 1218 """ 1219 relation_set = getattr(self, link) 1220 if related_class == self.get_attribute_model(): 1221 # When deserializing a model together with 1222 # its attributes, clear all the exising attributes to ensure 1223 # db consistency. Note 'update' won't be sufficient, as we also 1224 # want to remove any attributes that no longer exist in |data|. 1225 # 1226 # core_filters is a dictionary of filters, defines how 1227 # RelatedMangager would query for the 1-to-many relationship. E.g. 1228 # Host.objects.get( 1229 # id=20).hostattribute_set.core_filters = {host_id:20} 1230 # We use it to delete objects related to the current object. 1231 related_class.objects.filter(**relation_set.core_filters).delete() 1232 for serialized in data: 1233 relation_set.add(related_class.deserialize(serialized)) 1234 1235 1236 @classmethod 1237 def get_attribute_model(cls): 1238 """Return the attribute model. 1239 1240 Subclass with attribute-like model should override this to 1241 return the attribute model class. This method will be 1242 called by _deserialize_2m_relation to determine whether 1243 to clear the one-to-many relations first on deserialization of object. 1244 """ 1245 return None 1246 1247 1248class ModelWithInvalid(ModelExtensions): 1249 """ 1250 Overrides model methods save() and delete() to support invalidation in 1251 place of actual deletion. Subclasses must have a boolean "invalid" 1252 field. 1253 """ 1254 1255 def save(self, *args, **kwargs): 1256 """Saves the model""" 1257 first_time = (self.id is None) 1258 if first_time: 1259 # see if this object was previously added and invalidated 1260 my_name = getattr(self, self.name_field) 1261 filters = {self.name_field : my_name, 'invalid' : True} 1262 try: 1263 old_object = self.__class__.objects.get(**filters) 1264 self.resurrect_object(old_object) 1265 except self.DoesNotExist: 1266 # no existing object 1267 pass 1268 1269 super(ModelWithInvalid, self).save(*args, **kwargs) 1270 1271 1272 def resurrect_object(self, old_object): 1273 """ 1274 Called when self is about to be saved for the first time and is actually 1275 "undeleting" a previously deleted object. Can be overridden by 1276 subclasses to copy data as desired from the deleted entry (but this 1277 superclass implementation must normally be called). 1278 """ 1279 self.id = old_object.id 1280 1281 1282 def clean_object(self): 1283 """ 1284 This method is called when an object is marked invalid. 1285 Subclasses should override this to clean up relationships that 1286 should no longer exist if the object were deleted. 1287 """ 1288 pass 1289 1290 1291 def delete(self): 1292 """Deletes the model""" 1293 self.invalid = self.invalid 1294 assert not self.invalid 1295 self.invalid = True 1296 self.save() 1297 self.clean_object() 1298 1299 1300 @classmethod 1301 def get_valid_manager(cls): 1302 return cls.valid_objects 1303 1304 1305 class Manipulator(object): 1306 """ 1307 Force default manipulators to look only at valid objects - 1308 otherwise they will match against invalid objects when checking 1309 uniqueness. 1310 """ 1311 @classmethod 1312 def _prepare(cls, model): 1313 super(ModelWithInvalid.Manipulator, cls)._prepare(model) 1314 cls.manager = model.valid_objects 1315 1316 1317class ModelWithAttributes(object): 1318 """ 1319 Mixin class for models that have an attribute model associated with them. 1320 The attribute model is assumed to have its value field named "value". 1321 """ 1322 1323 def _get_attribute_model_and_args(self, attribute): 1324 """ 1325 Subclasses should override this to return a tuple (attribute_model, 1326 keyword_args), where attribute_model is a model class and keyword_args 1327 is a dict of args to pass to attribute_model.objects.get() to get an 1328 instance of the given attribute on this object. 1329 """ 1330 raise NotImplementedError 1331 1332 1333 def _is_replaced_by_static_attribute(self, attribute): 1334 """ 1335 Subclasses could override this to indicate whether it has static 1336 attributes. 1337 """ 1338 return False 1339 1340 1341 def set_attribute(self, attribute, value): 1342 if self._is_replaced_by_static_attribute(attribute): 1343 raise error.UnmodifiableAttributeException( 1344 'Failed to set attribute "%s" for host "%s" since it ' 1345 'is static. Use go/chromeos-skylab-inventory-tools to ' 1346 'modify this attribute.' % (attribute, self.hostname)) 1347 1348 attribute_model, get_args = self._get_attribute_model_and_args( 1349 attribute) 1350 attribute_object, _ = attribute_model.objects.get_or_create(**get_args) 1351 attribute_object.value = value 1352 attribute_object.save() 1353 1354 1355 def delete_attribute(self, attribute): 1356 """Deletes an attribute""" 1357 if self._is_replaced_by_static_attribute(attribute): 1358 raise error.UnmodifiableAttributeException( 1359 'Failed to delete attribute "%s" for host "%s" since it ' 1360 'is static. Use go/chromeos-skylab-inventory-tools to ' 1361 'modify this attribute.' % (attribute, self.hostname)) 1362 1363 attribute_model, get_args = self._get_attribute_model_and_args( 1364 attribute) 1365 try: 1366 attribute_model.objects.get(**get_args).delete() 1367 except attribute_model.DoesNotExist: 1368 pass 1369 1370 1371 def set_or_delete_attribute(self, attribute, value): 1372 if value is None: 1373 self.delete_attribute(attribute) 1374 else: 1375 self.set_attribute(attribute, value) 1376 1377 1378class ModelWithHashManager(dbmodels.Manager): 1379 """Manager for use with the ModelWithHash abstract model class""" 1380 1381 def create(self, **kwargs): 1382 """Always raises exception.""" 1383 raise Exception('ModelWithHash manager should use get_or_create() ' 1384 'instead of create()') 1385 1386 1387 def get_or_create(self, **kwargs): 1388 kwargs['the_hash'] = self.model._compute_hash(**kwargs) 1389 return super(ModelWithHashManager, self).get_or_create(**kwargs) 1390 1391 1392class ModelWithHash(dbmodels.Model): 1393 """Superclass with methods for dealing with a hash column""" 1394 1395 the_hash = dbmodels.CharField(max_length=40, unique=True) 1396 1397 objects = ModelWithHashManager() 1398 1399 class Meta: 1400 """Overrides dbmodels.Model.Meta.""" 1401 abstract = True 1402 1403 1404 @classmethod 1405 def _compute_hash(cls, **kwargs): 1406 raise NotImplementedError('Subclasses must override _compute_hash()') 1407 1408 1409 def save(self, force_insert=False, **kwargs): 1410 """Prevents saving the model in most cases 1411 1412 We want these models to be immutable, so the generic save() operation 1413 will not work. These models should be instantiated through their the 1414 model.objects.get_or_create() method instead. 1415 1416 The exception is that save(force_insert=True) will be allowed, since 1417 that creates a new row. However, the preferred way to make instances of 1418 these models is through the get_or_create() method. 1419 """ 1420 if not force_insert: 1421 # Allow a forced insert to happen; if it's a duplicate, the unique 1422 # constraint will catch it later anyways 1423 raise Exception('ModelWithHash is immutable') 1424 super(ModelWithHash, self).save(force_insert=force_insert, **kwargs) 1425