#!/usr/bin/python
#pylint: disable-msg=C0111

"""Unit tests for the scheduler's DB-backed models (scheduler_models)."""

import datetime
import functools
import unittest

# `common` and `setup_django_environment` are not referenced below; they are
# presumably imported for import-time side effects (path / Django setup).
import common
from autotest_lib.frontend import setup_django_environment
from autotest_lib.frontend.afe import frontend_test_utils
from autotest_lib.client.common_lib import host_queue_entry_states
from autotest_lib.database import database_connection
from autotest_lib.frontend.afe import models, model_attributes
from autotest_lib.scheduler import monitor_db
from autotest_lib.scheduler import scheduler_lib
from autotest_lib.scheduler import scheduler_models

_DEBUG = False


class BaseSchedulerModelsTest(unittest.TestCase,
                              frontend_test_utils.FrontendTestMixin):
    """Shared fixture for scheduler_models tests.

    setUp() runs the frontend common setup, then connects a
    TranslatingDatabase test connection and stubs it in as
    scheduler_models._db so the scheduler models and the Django frontend
    share one fresh test database.
    """
    _config_section = 'AUTOTEST_WEB'

    def _do_query(self, sql):
        """Execute a raw SQL statement against the test database."""
        self._database.execute(sql)


    def _set_monitor_stubs(self):
        """Connect a test database and stub it in as scheduler_models._db."""
        # Clear the instance cache as this is a brand new database.
        scheduler_models.DBObject._clear_instance_cache()

        self._database = (
            database_connection.TranslatingDatabase.get_test_database(
                translators=scheduler_lib._DB_TRANSLATORS))
        self._database.connect(db_type='django')
        self._database.debug = _DEBUG

        self.god.stub_with(scheduler_models, '_db', self._database)


    def setUp(self):
        self._frontend_common_setup()
        self._set_monitor_stubs()


    def tearDown(self):
        self._database.disconnect()
        self._frontend_common_teardown()


    def _update_hqe(self, set, where=''):
        """Run an UPDATE statement on afe_host_queue_entries.

        @param set: Body of the SQL SET clause. (The parameter name shadows
                the builtin but is kept for callers passing it by keyword.)
        @param where: Optional body of the SQL WHERE clause; when empty, no
                WHERE clause is added and every row is updated.
        """
        query = 'UPDATE afe_host_queue_entries SET ' + set
        if where:
            query += ' WHERE ' + where
        self._do_query(query)


class DBObjectTest(BaseSchedulerModelsTest):
    """Tests for generic DBObject behaviour, exercised via Host/HQE."""

    def test_compare_fields_in_row(self):
        host = scheduler_models.Host(id=1)
        fields = list(host._fields)
        row_data = [getattr(host, fieldname) for fieldname in fields]
        self.assertEqual({}, host._compare_fields_in_row(row_data))
        row_data[fields.index('hostname')] = 'spam'
        self.assertEqual({'hostname': ('host1', 'spam')},
                         host._compare_fields_in_row(row_data))
        row_data[fields.index('id')] = 23
        self.assertEqual({'hostname': ('host1', 'spam'), 'id': (1, 23)},
                         host._compare_fields_in_row(row_data))


    def test_compare_fields_in_row_datetime_ignores_microseconds(self):
        # Plain decimal day-of-month: the leading-zero form (07) is an octal
        # literal in Python 2 and a SyntaxError in Python 3.
        datetime_with_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 7890)
        datetime_without_us = datetime.datetime(2009, 10, 7, 12, 34, 56, 0)
        class TestTable(scheduler_models.DBObject):
            _table_name = 'test_table'
            _fields = ('id', 'test_datetime')
        tt = TestTable(row=[1, datetime_without_us])
        self.assertEqual({}, tt._compare_fields_in_row([1, datetime_with_us]))


    def test_always_query(self):
        host_a = scheduler_models.Host(id=2)
        self.assertEqual(host_a.hostname, 'host2')
        self._do_query('UPDATE afe_hosts SET hostname="host2-updated" '
                       'WHERE id=2')
        host_b = scheduler_models.Host(id=2, always_query=True)
        self.assertIs(host_a, host_b, 'Cached instance not returned.')
        self.assertEqual(host_a.hostname, 'host2-updated',
                         'Database was not re-queried')

        # If either of these are called, a query was made when it shouldn't be.
        host_a._compare_fields_in_row = lambda _: self.fail('eek! a query!')
        host_a._update_fields_from_row = host_a._compare_fields_in_row
        host_c = scheduler_models.Host(id=2, always_query=False)
        self.assertIs(host_a, host_c, 'Cached instance not returned')


    def test_delete(self):
        host = scheduler_models.Host(id=3)
        host.delete()
        # After delete() the id must not resolve, with or without
        # always_query.
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=False)
        self.assertRaises(scheduler_models.DBError, scheduler_models.Host,
                          id=3, always_query=True)


    def test_save(self):
        # Dummy Job to avoid creating one in the HostQueueEntry __init__.
        class MockJob(object):
            def __init__(self, id, row):
                pass
            def tag(self):
                return 'MockJob'
        self.god.stub_with(scheduler_models, 'Job', MockJob)
        hqe = scheduler_models.HostQueueEntry(
                new_record=True,
                row=[0, 1, 2, 'Queued', None, 0, 0, 0, '.', None, False, None,
                     None])
        hqe.save()
        new_id = hqe.id
        # Force a re-query and verify that the correct data was stored.
        scheduler_models.DBObject._clear_instance_cache()
        hqe = scheduler_models.HostQueueEntry(id=new_id)
        self.assertEqual(hqe.id, new_id)
        self.assertEqual(hqe.job_id, 1)
        self.assertEqual(hqe.host_id, 2)
        self.assertEqual(hqe.status, 'Queued')
        self.assertEqual(hqe.meta_host, None)
        self.assertEqual(hqe.active, False)
        self.assertEqual(hqe.complete, False)
        self.assertEqual(hqe.deleted, False)
        self.assertEqual(hqe.execution_subdir, '.')
        self.assertEqual(hqe.started_on, None)
        self.assertEqual(hqe.finished_on, None)


class HostTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Host."""

    def setUp(self):
        super(HostTest, self).setUp()
        # Tests below overwrite this module-level flag; tearDown restores it.
        self.old_config = scheduler_models.RESPECT_STATIC_LABELS


    def tearDown(self):
        # Restore the flag even if the base teardown fails, so later tests
        # do not inherit a modified RESPECT_STATIC_LABELS.
        try:
            super(HostTest, self).tearDown()
        finally:
            scheduler_models.RESPECT_STATIC_LABELS = self.old_config


    def _setup_static_labels(self):
        """Create a host with both regular and static labels.

        The regular label 'static_platform' (platform=False) is registered as
        a ReplacedLabel, and a static label of the same name exists with
        platform=True. The static label 'no_reference_label' has no regular
        counterpart.

        @return The saved models.Host.
        """
        label1 = models.Label.objects.create(name='non_static_label')
        non_static_platform = models.Label.objects.create(
                name='static_platform', platform=False)
        models.ReplacedLabel.objects.create(label_id=non_static_platform.id)

        static_label1 = models.StaticLabel.objects.create(
                name='no_reference_label', platform=False)
        static_platform = models.StaticLabel.objects.create(
                name=non_static_platform.name, platform=True)

        host1 = models.Host.objects.create(hostname='test_host')
        host1.labels.add(label1)
        host1.labels.add(non_static_platform)
        host1.static_labels.add(static_label1)
        host1.static_labels.add(static_platform)
        host1.save()
        return host1


    def test_platform_and_labels_with_respect(self):
        scheduler_models.RESPECT_STATIC_LABELS = True
        test_host = self._setup_static_labels()
        host = scheduler_models.Host(id=test_host.id)
        platform, all_labels = host.platform_and_labels()
        self.assertEqual(platform, 'static_platform')
        self.assertNotIn('no_reference_label', all_labels)
        self.assertEqual(all_labels, ['non_static_label', 'static_platform'])


    def test_platform_and_labels_without_respect(self):
        scheduler_models.RESPECT_STATIC_LABELS = False
        test_host = self._setup_static_labels()
        host = scheduler_models.Host(id=test_host.id)
        platform, all_labels = host.platform_and_labels()
        self.assertIsNone(platform)
        self.assertEqual(all_labels, ['non_static_label', 'static_platform'])


    def test_cmp_for_sort(self):
        expected_order = [
                'alice', 'Host1', 'host2', 'host3', 'host09', 'HOST010',
                'host10', 'host11', 'yolkfolk']
        hostname_idx = list(scheduler_models.Host._fields).index('hostname')
        row = [None] * len(scheduler_models.Host._fields)
        hosts = []
        for hostname in expected_order:
            row[hostname_idx] = hostname
            hosts.append(scheduler_models.Host(row=row, new_record=True))

        host1 = hosts[expected_order.index('Host1')]
        host010 = hosts[expected_order.index('HOST010')]
        host10 = hosts[expected_order.index('host10')]
        host3 = hosts[expected_order.index('host3')]
        alice = hosts[expected_order.index('alice')]
        self.assertEqual(0, scheduler_models.Host.cmp_for_sort(host10, host10))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host10, host010))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host010, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host010))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host10))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host3, host010))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, host1))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(host1, host3))
        self.assertEqual(-1, scheduler_models.Host.cmp_for_sort(alice, host3))
        self.assertEqual(1, scheduler_models.Host.cmp_for_sort(host3, alice))
        self.assertEqual(0, scheduler_models.Host.cmp_for_sort(alice, alice))

        # cmp_to_key gives the same ordering as sort(cmp=...), which only
        # exists in Python 2.
        sort_key = functools.cmp_to_key(scheduler_models.Host.cmp_for_sort)

        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])

        hosts.reverse()
        hosts.sort(key=sort_key)
        self.assertEqual(expected_order, [h.hostname for h in hosts])


class HostQueueEntryTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.HostQueueEntry."""

    def _create_hqe(self, dependency_labels=(), **create_job_kwargs):
        """Create a single-entry job and return its scheduler-side HQE.

        @param dependency_labels: Labels added to the job's dependencies.
        @param create_job_kwargs: Forwarded to self._create_job().
        @return The only scheduler_models.HostQueueEntry of the new job.
        """
        job = self._create_job(**create_job_kwargs)
        for label in dependency_labels:
            job.dependency_labels.add(label)
        hqes = list(scheduler_models.HostQueueEntry.fetch(
                where='job_id=%d' % job.id))
        self.assertEqual(1, len(hqes))
        return hqes[0]


    def _check_hqe_labels(self, hqe, expected_labels):
        """Assert hqe.get_labels() names exactly expected_labels."""
        expected_labels = set(expected_labels)
        label_names = set(label.name for label in hqe.get_labels())
        self.assertEqual(expected_labels, label_names)


    def test_get_labels_empty(self):
        hqe = self._create_hqe(hosts=[1])
        labels = list(hqe.get_labels())
        self.assertEqual([], labels)


    def test_get_labels_metahost(self):
        hqe = self._create_hqe(metahosts=[2])
        self._check_hqe_labels(hqe, ['label2'])


    def test_get_labels_dependencies(self):
        hqe = self._create_hqe(dependency_labels=(self.label3,),
                               metahosts=[1])
        self._check_hqe_labels(hqe, ['label1', 'label3'])


    def setup_abort_test(self, agent_finished=True):
        """Setup the variables for testing abort method.

        @param agent_finished: True to mock agent is finished before aborting
                               the hqe.
        @return hqe, dispatcher: Mock object of hqe and dispatcher to be used
                               to test abort method.
        """
        hqe = self._create_hqe(hosts=[1])
        hqe.aborted = True
        hqe.complete = False
        hqe.status = models.HostQueueEntry.Status.STARTING
        hqe.started_on = datetime.datetime.now()

        dispatcher = self.god.create_mock_class(monitor_db.Dispatcher,
                                                'Dispatcher')
        agent = self.god.create_mock_class(monitor_db.Agent, 'Agent')
        dispatcher.get_agents_for_entry.expect_call(hqe).and_return([agent])
        agent.is_done.expect_call().and_return(agent_finished)
        return hqe, dispatcher


    def test_abort_fail_with_unfinished_agent(self):
        """abort should fail if the hqe still has agent not finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=False)
        self.assertIsNone(hqe.finished_on)
        with self.assertRaises(AssertionError):
            hqe.abort(dispatcher)
        self.god.check_playback()
        # abort failed, finished_on should not be set
        self.assertIsNone(hqe.finished_on)


    def test_abort_success(self):
        """abort should succeed if all agents for the hqe are finished.
        """
        hqe, dispatcher = self.setup_abort_test(agent_finished=True)
        self.assertIsNone(hqe.finished_on)
        hqe.abort(dispatcher)
        self.god.check_playback()
        self.assertIsNotNone(hqe.finished_on)


    def test_set_finished_on(self):
        """Test that finished_on is set when hqe completes."""
        for status in host_queue_entry_states.Status.values:
            hqe = self._create_hqe(hosts=[1])
            hqe.started_on = datetime.datetime.now()
            hqe.job.update_field('shard_id', 3)
            self.assertIsNone(hqe.finished_on)
            hqe.set_status(status)
            if status in host_queue_entry_states.COMPLETE_STATUSES:
                self.assertIsNotNone(hqe.finished_on)
                self.assertIsNone(hqe.job.shard_id)
            else:
                self.assertIsNone(hqe.finished_on)
                self.assertEqual(hqe.job.shard_id, 3)


class JobTest(BaseSchedulerModelsTest):
    """Tests for scheduler_models.Job, mainly pre-job special tasks."""

    def setUp(self):
        super(JobTest, self).setUp()
        # Special tasks created through the stub below are collected here.
        # _test_pre_job_tasks_helper() resets the list; initialising it here
        # as well keeps the stub usable if a task is created elsewhere.
        self._tasks = []

        def _mock_create(**kwargs):
            task = models.SpecialTask(**kwargs)
            task.save()
            self._tasks.append(task)
            # Mirror objects.create(), which returns the created instance.
            return task
        self.god.stub_with(models.SpecialTask.objects, 'create', _mock_create)


    def _test_pre_job_tasks_helper(
            self, reboot_before=model_attributes.RebootBefore.ALWAYS):
        """Schedule pre-job tasks for HQE 1 and return what was created.

        @param reboot_before: RebootBefore value forced onto the HQE's job
                before scheduling.
        @return List of models.SpecialTask objects created through the
                stubbed SpecialTask.objects.create while running
                HQE._do_schedule_pre_job_tasks(), in creation order.
        """
        self._tasks = []
        queue_entry = scheduler_models.HostQueueEntry.fetch('id = 1')[0]
        queue_entry.job.reboot_before = reboot_before
        queue_entry._do_schedule_pre_job_tasks()
        return self._tasks


    def test_job_request_abort(self):
        django_job = self._create_job(hosts=[5, 6])
        job = scheduler_models.Job(django_job.id)
        job.request_abort()
        django_hqes = list(models.HostQueueEntry.objects.filter(job=job.id))
        for hqe in django_hqes:
            self.assertTrue(hqe.aborted)


    def _check_special_tasks(self, tasks, task_types):
        """Assert tasks match task_types pairwise, all for host 1.

        @param tasks: SpecialTask objects, as returned by
                _test_pre_job_tasks_helper().
        @param task_types: List of (task type, queue entry id) pairs; a falsy
                queue entry id skips the queue-entry check for that task.
        """
        self.assertEqual(len(tasks), len(task_types))
        for task, (task_type, queue_entry_id) in zip(tasks, task_types):
            self.assertEqual(task.task, task_type)
            self.assertEqual(task.host.id, 1)
            if queue_entry_id:
                self.assertEqual(task.queue_entry.id, queue_entry_id)


    def test_run_asynchronous(self):
        self._create_job(hosts=[1, 2])

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2])
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_verify(self):
        self._create_job(hosts=[1, 2], synchronous=True)

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_synchronous_skip_verify(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.RESET, 1)])


    def test_run_asynchronous_do_not_reset(self):
        # NOTE(review): this test used to be shadowed by a later method with
        # the same name and therefore never ran. Confirm its expectation (no
        # special tasks, with the helper's default RebootBefore.ALWAYS)
        # still holds now that it is collected.
        job = self._create_job(hosts=[1, 2])
        job.run_reset = False
        job.run_verify = False
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self.assertEqual(tasks, [])


    def test_run_synchronous_do_not_reset_no_RebootBefore(self):
        job = self._create_job(hosts=[1, 2], synchronous=True)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_run_asynchronous_do_not_reset_no_RebootBefore(self):
        # Asynchronous counterpart of the synchronous test above; the name
        # must differ from test_run_asynchronous_do_not_reset or one of the
        # two definitions silently replaces the other.
        job = self._create_job(hosts=[1, 2], synchronous=False)
        job.reboot_before = model_attributes.RebootBefore.NEVER
        job.save()

        tasks = self._test_pre_job_tasks_helper(
                reboot_before=model_attributes.RebootBefore.NEVER)

        self._check_special_tasks(tasks, [(models.SpecialTask.Task.VERIFY, 1)])


    def test_reboot_before_always(self):
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.ALWAYS
        job.save()

        tasks = self._test_pre_job_tasks_helper()

        self._check_special_tasks(tasks, [
                (models.SpecialTask.Task.RESET, None)
            ])


    def _test_reboot_before_if_dirty_helper(self):
        """Run a RebootBefore.IF_DIRTY job on host 1; expect a single RESET."""
        job = self._create_job(hosts=[1])
        job.reboot_before = model_attributes.RebootBefore.IF_DIRTY
        job.save()

        tasks = self._test_pre_job_tasks_helper()
        task_types = [(models.SpecialTask.Task.RESET, None)]

        self._check_special_tasks(tasks, task_types)


    def test_reboot_before_if_dirty(self):
        models.Host.smart_get(1).update_object(dirty=True)
        self._test_reboot_before_if_dirty_helper()


    def test_reboot_before_not_dirty(self):
        models.Host.smart_get(1).update_object(dirty=False)
        self._test_reboot_before_if_dirty_helper()


if __name__ == '__main__':
    unittest.main()