• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright (c) 2014 The Chromium OS Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4
5import datetime
6import mox
7import unittest
8
9import common
10
11from autotest_lib.frontend import setup_django_environment
12from autotest_lib.frontend.afe import frontend_test_utils
13from autotest_lib.frontend.afe import models
14from autotest_lib.client.common_lib import error
15from autotest_lib.client.common_lib import global_config
16from autotest_lib.server.cros.dynamic_suite import frontend_wrappers
17from autotest_lib.scheduler.shard import shard_client
18
19
class ShardClientTest(mox.MoxTestBase,
                      frontend_test_utils.FrontendTestMixin):
    """Unit tests for functions in shard_client.py.

    The tests record the 'shard_heartbeat' RPCs they expect on a mocked
    frontend_wrappers.RetryingAFE (see setup_mocks / expect_heartbeat), switch
    mox to replay mode and then drive the client returned by
    shard_client.get_shard_client().
    """


    GLOBAL_AFE_HOSTNAME = 'foo_autotest'


    def setUp(self):
        """Override the global AFE hostname and set up the frontend test DB."""
        super(ShardClientTest, self).setUp()

        global_config.global_config.override_config_value(
                'SHARD', 'global_afe_hostname', self.GLOBAL_AFE_HOSTNAME)

        self._frontend_common_setup(fill_data=False)


    def tearDown(self):
        """Tear down the frontend test environment and reset global_config."""
        try:
            self._frontend_common_teardown()
        finally:
            # Without this global_config will keep state over test cases.
            # Done in a finally clause so that a failing frontend teardown
            # cannot leak overridden values into later tests.
            global_config.global_config.reset_config_values()


    def setup_mocks(self):
        """Stub out RetryingAFE and record its expected construction.

        The class is replaced via mox.StubOutClassWithMocks; the recorded
        constructor call accepts any server and requires delay_sec=5 and
        timeout_min=5. The object returned by that recorded call is kept in
        self.afe so that expect_heartbeat() can record RPCs on it.
        """
        self.mox.StubOutClassWithMocks(frontend_wrappers, 'RetryingAFE')
        self.afe = frontend_wrappers.RetryingAFE(server=mox.IgnoreArg(),
                                                 delay_sec=5,
                                                 timeout_min=5)


    def setup_global_config(self):
        """Set SHARD/is_slave_shard to 'True' and SHARD/shard_hostname to
        'host1' in the global config."""
        global_config.global_config.override_config_value(
                'SHARD', 'is_slave_shard', 'True')
        global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', 'host1')


    @staticmethod
    def _list_or_empty(value):
        """Return a new empty list if value is None, else value unchanged."""
        return [] if value is None else value


    def expect_heartbeat(self, shard_hostname='host1',
                         known_job_ids=None, known_host_ids=None,
                         known_host_statuses=None, hqes=None, jobs=None,
                         side_effect=None, return_hosts=None, return_jobs=None,
                         return_suite_keyvals=None,
                         return_incorrect_hosts=None):
        """Record one expected 'shard_heartbeat' RPC on the mocked AFE.

        Requires setup_mocks() to have been called, since the expectation is
        recorded on self.afe. Every list parameter left as None is treated as
        an empty list; None is used as the default instead of [] so that no
        list object is shared between calls.

        @param shard_hostname: Expected shard_hostname argument of the RPC.
        @param known_job_ids: Expected known_job_ids argument.
        @param known_host_ids: Expected known_host_ids argument.
        @param known_host_statuses: Expected known_host_statuses argument.
        @param hqes: Expected hqes argument (may be a mox comparator such as
                mox.IgnoreArg()).
        @param jobs: Expected jobs argument (may be a mox comparator).
        @param side_effect: Optional callable attached to the expected call
                with WithSideEffects.
        @param return_hosts: Value of 'hosts' in the mocked response.
        @param return_jobs: Value of 'jobs' in the mocked response.
        @param return_suite_keyvals: Value of 'suite_keyvals' in the mocked
                response.
        @param return_incorrect_hosts: Value of 'incorrect_host_ids' in the
                mocked response.
        """
        as_list = self._list_or_empty

        call = self.afe.run(
            'shard_heartbeat', shard_hostname=shard_hostname,
            hqes=as_list(hqes), jobs=as_list(jobs),
            known_job_ids=as_list(known_job_ids),
            known_host_ids=as_list(known_host_ids),
            known_host_statuses=as_list(known_host_statuses),
            )

        if side_effect:
            call = call.WithSideEffects(side_effect)

        call.AndReturn({
                'hosts': as_list(return_hosts),
                'jobs': as_list(return_jobs),
                'suite_keyvals': as_list(return_suite_keyvals),
                'incorrect_host_ids': as_list(return_incorrect_hosts),
            })


    def _assert_heartbeat_fails_in_processing(self, sut):
        """Run one heartbeat whose response processing raises RuntimeError.

        sut.process_heartbeat_response is temporarily replaced by a function
        that always raises RuntimeError, sut.do_heartbeat is asserted to let
        that error propagate, and the original attribute is put back
        afterwards (also when the assertion fails).

        @param sut: The shard client under test.
        """
        original_process_heartbeat_response = sut.process_heartbeat_response

        def failing_process_heartbeat_response(*args, **kwargs):
            raise RuntimeError

        sut.process_heartbeat_response = failing_process_heartbeat_response
        try:
            self.assertRaises(RuntimeError, sut.do_heartbeat)
        finally:
            sut.process_heartbeat_response = (
                    original_process_heartbeat_response)


    def _get_sample_serialized_host(self):
        """Return a new dict describing a serialized host.

        The host has id 2, hostname 'host1' and status 'Ready'. A fresh dict
        is built on every call, so callers may modify the result.
        """
        return {'aclgroup_set': [],
                'dirty': True,
                'hostattribute_set': [],
                'hostname': u'host1',
                u'id': 2,
                'invalid': False,
                'labels': [],
                'leased': True,
                'lock_time': None,
                'locked': False,
                'protection': 0,
                'shard': None,
                'status': u'Ready'}


    def _get_sample_serialized_job(self):
        """Return a new dict describing a serialized job.

        The job has id 1, one dependency label ('board:lumpy'), one queued
        host queue entry with id 1 and is assigned to shard 'shard1'. A fresh
        (nested) structure is built on every call, so callers may modify the
        result.
        """
        return {'control_file': u'foo',
                'control_type': 2,
                'created_on': datetime.datetime(2014, 9, 23, 15, 56, 10, 0),
                'dependency_labels': [{u'id': 1,
                                       'invalid': False,
                                       'kernel_config': u'',
                                       'name': u'board:lumpy',
                                       'only_if_needed': False,
                                       'platform': False}],
                'email_list': u'',
                'hostqueueentry_set': [{'aborted': False,
                                        'active': False,
                                        'complete': False,
                                        'deleted': False,
                                        'execution_subdir': u'',
                                        'finished_on': None,
                                        u'id': 1,
                                        'meta_host': {u'id': 1,
                                                      'invalid': False,
                                                      'kernel_config': u'',
                                                      'name': u'board:lumpy',
                                                      'only_if_needed': False,
                                                      'platform': False},
                                        'started_on': None,
                                        'status': u'Queued'}],
                u'id': 1,
                'jobkeyval_set': [],
                'max_runtime_hrs': 72,
                'max_runtime_mins': 1440,
                'name': u'dummy',
                'owner': u'autotest_system',
                'parse_failed_repair': True,
                'priority': 40,
                'parent_job_id': 0,
                'reboot_after': 0,
                'reboot_before': 1,
                'run_reset': True,
                'run_verify': False,
                'shard': {'hostname': u'shard1', u'id': 1},
                'synch_count': 0,
                'test_retry': 0,
                'timeout': 24,
                'timeout_mins': 1440}


    def _get_sample_serialized_suite_keyvals(self):
        """Return a new dict describing a serialized keyval of job id 0."""
        return {'id': 1,
                'job_id': 0,
                'key': 'test_key',
                'value': 'test_value'}


    def testHeartbeat(self):
        """Trigger heartbeat, verify RPCs and persisting of the responses."""
        self.setup_mocks()

        # Only the shard hostname is configured here; unlike
        # setup_global_config(), 'is_slave_shard' is left untouched.
        global_config.global_config.override_config_value(
                'SHARD', 'shard_hostname', 'host1')

        # 1st heartbeat: the response carries one host, one job and one suite
        # keyval.
        self.expect_heartbeat(
                return_hosts=[self._get_sample_serialized_host()],
                return_jobs=[self._get_sample_serialized_job()],
                return_suite_keyvals=[
                        self._get_sample_serialized_suite_keyvals()])

        # 2nd heartbeat: the host and the job must be reported as known. The
        # response carries the same host id under a different hostname.
        modified_sample_host = self._get_sample_serialized_host()
        modified_sample_host['hostname'] = 'host2'

        self.expect_heartbeat(
                return_hosts=[modified_sample_host],
                known_host_ids=[modified_sample_host['id']],
                known_host_statuses=[modified_sample_host['status']],
                known_job_ids=[1])


        def verify_upload_jobs_and_hqes(name, shard_hostname, jobs, hqes,
                                        known_host_ids, known_host_statuses,
                                        known_job_ids):
            """Side effect of the 3rd heartbeat: check the uploaded records.

            The signature mirrors the arguments of the recorded afe.run()
            call, which is why parameters the checks do not use are declared.
            """
            self.assertEqual(len(jobs), 1)
            self.assertEqual(len(hqes), 1)
            self.assertEqual(hqes[0]['status'], 'Completed')


        # 3rd heartbeat: exactly one job and one completed HQE are uploaded
        # and no job id is reported as known any more.
        self.expect_heartbeat(
                jobs=mox.IgnoreArg(), hqes=mox.IgnoreArg(),
                known_host_ids=[modified_sample_host['id']],
                known_host_statuses=[modified_sample_host['status']],
                known_job_ids=[], side_effect=verify_upload_jobs_and_hqes)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        sut.do_heartbeat()

        # Check if dummy object was saved to DB
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        # Check if suite keyval was saved to DB
        suite_keyval = models.JobKeyval.objects.filter(job_id=0)[0]
        self.assertEqual(suite_keyval.key, 'test_key')

        sut.do_heartbeat()

        # Ensure it wasn't overwritten
        host = models.Host.objects.get(id=2)
        self.assertEqual(host.hostname, 'host1')

        # Detach the job from its shard and complete its HQE so that the 3rd
        # heartbeat has something to upload.
        job = models.Job.objects.all()[0]
        job.shard = None
        job.save()
        hqe = job.hostqueueentry_set.all()[0]
        hqe.status = 'Completed'
        hqe.save()

        sut.do_heartbeat()


        self.mox.VerifyAll()


    def testRemoveInvalidHosts(self):
        """Hosts reported as incorrect by the master are invalidated."""
        self.setup_mocks()
        self.setup_global_config()

        host_serialized = self._get_sample_serialized_host()
        host_id = host_serialized[u'id']

        # 1st heartbeat: return a host.
        # 2nd heartbeat: "delete" that host. Also send a spurious extra ID
        # that isn't present to ensure shard client doesn't crash. (Note: the
        # delete operation doesn't actually delete the db entry. Django model
        # logic instead simply marks it as invalid.)
        # 3rd heartbeat: host is no longer present in shard's request.

        self.expect_heartbeat(return_hosts=[host_serialized])
        self.expect_heartbeat(known_host_ids=[host_id],
                              known_host_statuses=[u'Ready'],
                              return_incorrect_hosts=[host_id, 42])
        self.expect_heartbeat()

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        sut.do_heartbeat()
        host = models.Host.smart_get(host_id)
        self.assertFalse(host.invalid)

        # Host should no longer "exist" after the invalidation.
        # Why don't we simply count the number of hosts in db? Because the host
        # actually remains in the db, but simply has its invalid bit set to
        # True.
        sut.do_heartbeat()
        with self.assertRaises(models.Host.DoesNotExist):
            models.Host.smart_get(host_id)


        # Subsequent heartbeat no longer passes the host id as a known host.
        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testFailAndRedownloadJobs(self):
        """A job whose heartbeat response failed to process is sent again."""
        self.setup_mocks()
        self.setup_global_config()

        job1_serialized = self._get_sample_serialized_job()
        job2_serialized = self._get_sample_serialized_job()
        job2_serialized['id'] = 2
        job2_serialized['hostqueueentry_set'][0]['id'] = 2

        # 1st heartbeat: returns job1, but processing the response fails.
        # 2nd heartbeat: no job is reported as known; job1 and job2 returned.
        # 3rd heartbeat: both jobs are reported as known.
        # 4th heartbeat: only job2 is reported as known, because job1's HQEs
        # were marked complete in between.
        self.expect_heartbeat(return_jobs=[job1_serialized])
        self.expect_heartbeat(return_jobs=[job1_serialized, job2_serialized])
        self.expect_heartbeat(known_job_ids=[job1_serialized['id'],
                                             job2_serialized['id']])
        self.expect_heartbeat(known_job_ids=[job2_serialized['id']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        self._assert_heartbeat_fails_in_processing(sut)

        sut.do_heartbeat()
        sut.do_heartbeat()

        job1 = models.Job.objects.get(pk=job1_serialized['id'])
        job1.hostqueueentry_set.all().update(complete=True)

        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testFailAndRedownloadHosts(self):
        """A host whose heartbeat response failed to process is sent again."""
        self.setup_mocks()
        self.setup_global_config()

        host1_serialized = self._get_sample_serialized_host()
        host2_serialized = self._get_sample_serialized_host()
        host2_serialized['id'] = 3
        host2_serialized['hostname'] = 'host2'

        # 1st heartbeat: returns host1, but processing the response fails.
        # 2nd heartbeat: no host is reported as known; host1 and host2
        # returned.
        # 3rd heartbeat: both hosts are reported as known.
        self.expect_heartbeat(return_hosts=[host1_serialized])
        self.expect_heartbeat(return_hosts=[host1_serialized, host2_serialized])
        self.expect_heartbeat(known_host_ids=[host1_serialized['id'],
                                              host2_serialized['id']],
                              known_host_statuses=[host1_serialized['status'],
                                                   host2_serialized['status']])

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()

        self._assert_heartbeat_fails_in_processing(sut)

        # The failed heartbeat must not have persisted any host.
        self.assertEqual(models.Host.objects.count(), 0)

        sut.do_heartbeat()
        sut.do_heartbeat()

        self.mox.VerifyAll()


    def testHeartbeatNoShardMode(self):
        """Ensure an exception is thrown when run on a non-shard machine."""
        self.mox.ReplayAll()

        self.assertRaises(error.HeartbeatOnlyAllowedInShardModeException,
                          shard_client.get_shard_client)

        self.mox.VerifyAll()


    def testLoop(self):
        """Test looping over heartbeats and aborting that loop works."""
        self.setup_mocks()
        self.setup_global_config()

        global_config.global_config.override_config_value(
                'SHARD', 'heartbeat_pause_sec', '0.01')

        self.expect_heartbeat()

        # Bound before shutdown_sut is defined so the closure can refer to it;
        # the real client is assigned once the mocks are in replay mode.
        sut = None

        def shutdown_sut(*args, **kwargs):
            """Side effect of the 2nd heartbeat: ask the client to stop."""
            sut.shutdown()

        self.expect_heartbeat(side_effect=shutdown_sut)

        self.mox.ReplayAll()
        sut = shard_client.get_shard_client()
        sut.loop()

        self.mox.VerifyAll()
365
# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
368