• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python3
2#
3# Copyright (c) 2012 The Chromium OS Authors. All rights reserved.
4# Use of this source code is governed by a BSD-style license that can be
5# found in the LICENSE file.
6
7"""Unit tests for client/common_lib/cros/dev_server.py."""
8
9import six.moves.http_client
10import json
11import os
12import six
13from six.moves import urllib
14import time
15import unittest
16from unittest import mock
17from unittest.mock import patch, call
18
19import common
20from autotest_lib.client.bin import utils as bin_utils
21from autotest_lib.client.common_lib import android_utils
22from autotest_lib.client.common_lib import error
23from autotest_lib.client.common_lib import global_config
24from autotest_lib.client.common_lib.test_utils import comparators
25
26from autotest_lib.client.common_lib import utils
27from autotest_lib.client.common_lib.cros import dev_server
28from autotest_lib.client.common_lib.cros import retry
29
30
def retry_mock(ExceptionToCheck, timeout_min, exception_to_raise=None,
               label=None):
    """Stand-in for the real retry decorator that never retries.

    All arguments are accepted only so the call signature lines up with the
    decorator being replaced; none of them is used.

    @param ExceptionToCheck: the exception to check (ignored).
    @param timeout_min: amount of time in mins to wait before timing out
                        (ignored).
    @param exception_to_raise: the exception to raise in retry.retry
                               (ignored).
    @param label: used in debug messages (ignored).

    @return: a decorator that hands back the decorated function untouched.
    """
    def passthrough(func):
        """Return func unchanged.

        @param func: function being decorated.
        """
        return func

    return passthrough
50
51
class MockSshResponse(object):
    """Minimal fake of the result object returned by an ssh command.

    Exposes the three attributes set below: stdout, stderr and exit_status.
    """

    def __init__(self, output, exit_status=0):
        """Build a fake response.

        @param output: value stored as stdout.
        @param exit_status: value stored as exit_status (defaults to 0).
        """
        # stderr is always the same canned message, whatever the exit status.
        self.stderr = 'SSH connection error occurred.'
        self.exit_status = exit_status
        self.stdout = output
58        self.stderr = 'SSH connection error occurred.'
59
60
class MockSshError(error.CmdError):
    """Fake error.CmdError carrying a failed MockSshResponse."""

    def __init__(self):
        """Attach a canned failing response as result_obj.

        The parent initializer is not called; only result_obj is set.
        """
        failed_response = MockSshResponse('error', exit_status=255)
        self.result_obj = failed_response
66
67
# Canned HTTP 403 error shared by the tests below.
E403 = urllib.error.HTTPError(
        url='',
        code=six.moves.http_client.FORBIDDEN,
        msg='Error 403',
        hdrs=None,
        fp=six.StringIO('Expected.'))

# Canned HTTP 500 error shared by the tests below.
E500 = urllib.error.HTTPError(
        url='',
        code=six.moves.http_client.INTERNAL_SERVER_ERROR,
        msg='Error 500',
        hdrs=None,
        fp=six.StringIO('Expected.'))

# Canned command error whose result object is the failed mock ssh response.
CMD_ERROR = error.CmdError('error_cmd', MockSshError().result_obj)
79
80
class RunCallTest(unittest.TestCase):
    """Unit tests for ImageServerBase.run_call or DevServer.run_call.

    setUp replaces utils.run, urllib.request.urlopen and time.sleep with
    mocks, and tearDown restores the module-level
    dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER flag that individual tests
    flip to select the HTTP or SSH code path.
    """

    def setUp(self):
        """Set up the test: fixtures first, then the mocks."""
        self.test_call = 'http://nothing/test'
        self.hostname = 'nothing'
        self.contents = 'true'
        self.contents_readline = ['file/one', 'file/two']
        # Remembered so tearDown can undo whatever a test sets.
        self.save_ssh_config = dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER
        super(RunCallTest, self).setUp()

        run_patcher = patch.object(utils, 'run', spec=True)
        self.utils_run_mock = run_patcher.start()
        self.addCleanup(run_patcher.stop)

        urlopen_patcher = patch.object(urllib.request, 'urlopen', spec=True)
        self.urlopen_mock = urlopen_patcher.start()
        self.addCleanup(urlopen_patcher.stop)

        sleep_patcher = mock.patch('time.sleep', autospec=True)
        sleep_patcher.start()
        self.addCleanup(sleep_patcher.stop)

    def tearDown(self):
        """Tear down the test: restore the saved SSH-connection flag."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = self.save_ssh_config
        super(RunCallTest, self).tearDown()

    def testRunCallHTTPWithDownDevserver(self):
        """Test dev_server.ImageServerBase.run_call using http with arg:
        (call), when the first reply is the down-devserver message."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = False

        # First reply is the down-devserver message, second the real payload.
        self.urlopen_mock.side_effect = [
                six.StringIO(dev_server.ERR_MSG_FOR_DOWN_DEVSERVER),
                six.StringIO(self.contents)
        ]

        response = dev_server.ImageServerBase.run_call(self.test_call)

        self.assertEqual(self.contents, response)
        self.urlopen_mock.assert_called_with(
                comparators.Substring(self.test_call))

    def testRunCallSSHWithDownDevserver(self):
        """Test dev_server.ImageServerBase.run_call using ssh with arg:
        (call), when the first reply is the down-devserver message."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = True
        with patch.object(utils, 'get_restricted_subnet') as subnet_patch:
            subnet_patch.return_value = self.hostname

            # First reply is the down-devserver message, second the payload.
            self.utils_run_mock.side_effect = [
                    MockSshResponse(dev_server.ERR_MSG_FOR_DOWN_DEVSERVER),
                    MockSshResponse(self.contents)
            ]

            response = dev_server.ImageServerBase.run_call(self.test_call)

            self.assertEqual(self.contents, response)
            # No manual reset of the SSH flag here: tearDown restores it.
            expected_call = call(comparators.Substring(self.test_call),
                                 timeout=mock.ANY)
            self.utils_run_mock.assert_has_calls(
                    [expected_call, expected_call])
            subnet_patch.assert_called_with(self.hostname,
                                            utils.get_all_restricted_subnets())

    def testRunCallWithSingleCallHTTP(self):
        """Test dev_server.ImageServerBase.run_call using http with arg:
        (call)."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = False
        self.urlopen_mock.return_value = six.StringIO(self.contents)

        response = dev_server.ImageServerBase.run_call(self.test_call)

        self.assertEqual(self.contents, response)
        self.urlopen_mock.assert_called_with(
                comparators.Substring(self.test_call))

    def testRunCallWithCallAndReadlineHTTP(self):
        """Test dev_server.ImageServerBase.run_call using http with arg:
        (call, readline=True)."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = False
        body = '\n'.join(self.contents_readline)
        self.urlopen_mock.return_value = six.StringIO(body)

        response = dev_server.ImageServerBase.run_call(
                self.test_call, readline=True)

        self.assertEqual(self.contents_readline, response)
        self.urlopen_mock.assert_called_with(
                comparators.Substring(self.test_call))

    def testRunCallWithCallAndTimeoutHTTP(self):
        """Test dev_server.ImageServerBase.run_call using http with args:
        (call, timeout=xxx)."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = False
        self.urlopen_mock.return_value = six.StringIO(self.contents)

        response = dev_server.ImageServerBase.run_call(
                self.test_call, timeout=60)

        self.assertEqual(self.contents, response)
        self.urlopen_mock.assert_called_with(
                comparators.Substring(self.test_call), data=None)

    def testRunCallWithSingleCallSSH(self):
        """Test dev_server.ImageServerBase.run_call using ssh with arg:
        (call)."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = True
        with patch.object(utils, 'get_restricted_subnet') as subnet_patch:
            subnet_patch.return_value = self.hostname
            self.utils_run_mock.return_value = MockSshResponse(self.contents)

            response = dev_server.ImageServerBase.run_call(self.test_call)

            self.assertEqual(self.contents, response)
            subnet_patch.assert_called_with(self.hostname,
                                            utils.get_all_restricted_subnets())
            self.utils_run_mock.assert_called_with(
                    comparators.Substring(self.test_call), timeout=mock.ANY)

    def testRunCallWithCallAndReadlineSSH(self):
        """Test dev_server.ImageServerBase.run_call using ssh with args:
        (call, readline=True)."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = True
        with patch.object(utils, 'get_restricted_subnet') as subnet_patch:
            subnet_patch.return_value = self.hostname
            self.utils_run_mock.return_value = MockSshResponse(
                    '\n'.join(self.contents_readline))

            response = dev_server.ImageServerBase.run_call(self.test_call,
                                                           readline=True)

            self.assertEqual(self.contents_readline, response)
            subnet_patch.assert_called_with(self.hostname,
                                            utils.get_all_restricted_subnets())
            self.utils_run_mock.assert_called_with(
                    comparators.Substring(self.test_call), timeout=mock.ANY)

    def testRunCallWithCallAndTimeoutSSH(self):
        """Test dev_server.ImageServerBase.run_call using ssh with args:
        (call, timeout=xxx)."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = True
        with patch.object(utils, 'get_restricted_subnet') as subnet_patch:
            subnet_patch.return_value = self.hostname
            self.utils_run_mock.return_value = MockSshResponse(self.contents)

            response = dev_server.ImageServerBase.run_call(self.test_call,
                                                           timeout=60)

            self.assertEqual(self.contents, response)
            subnet_patch.assert_called_with(self.hostname,
                                            utils.get_all_restricted_subnets())
            self.utils_run_mock.assert_called_with(
                    comparators.Substring(self.test_call), timeout=mock.ANY)

    def testRunCallWithExceptionHTTP(self):
        """Test dev_server.ImageServerBase.run_call using http with raising
        exception."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = False
        self.urlopen_mock.side_effect = E500

        self.assertRaises(urllib.error.HTTPError,
                          dev_server.ImageServerBase.run_call,
                          self.test_call)
        self.urlopen_mock.assert_called_with(
                comparators.Substring(self.test_call))

    def testRunCallWithExceptionSSH(self):
        """Test dev_server.ImageServerBase.run_call using ssh with raising
        exception."""
        dev_server.ENABLE_SSH_CONNECTION_FOR_DEVSERVER = True
        with patch.object(utils, 'get_restricted_subnet') as subnet_patch:
            subnet_patch.return_value = self.hostname
            self.utils_run_mock.side_effect = MockSshError()

            self.assertRaises(error.CmdError,
                              dev_server.ImageServerBase.run_call,
                              self.test_call)
            subnet_patch.assert_called_with(self.hostname,
                                            utils.get_all_restricted_subnets())
            self.utils_run_mock.assert_called_with(
                    comparators.Substring(self.test_call), timeout=mock.ANY)

    def testRunCallByDevServerHTTP(self):
        """Test dev_server.DevServer.run_call, which uses http, and can be
        directly called by CrashServer."""
        self.urlopen_mock.return_value = six.StringIO(self.contents)

        response = dev_server.DevServer.run_call(
                self.test_call, timeout=60)

        self.assertEqual(self.contents, response)
        self.urlopen_mock.assert_called_with(
                comparators.Substring(self.test_call), data=None)
291
292
293class DevServerTest(unittest.TestCase):
294    """Unit tests for dev_server.DevServer.
295
296    @var _HOST: fake dev server host address.
297    """
298
299    _HOST = 'http://nothing'
300    _CRASH_HOST = 'http://nothing-crashed'
301    _CONFIG = global_config.global_config
302
303
304    def setUp(self):
305        """Set up the test"""
306        super(DevServerTest, self).setUp()
307        self.crash_server = dev_server.CrashServer(DevServerTest._CRASH_HOST)
308        self.dev_server = dev_server.ImageServer(DevServerTest._HOST)
309        self.android_dev_server = dev_server.AndroidBuildServer(
310                DevServerTest._HOST)
311        patcher = patch.object(utils, 'run', spec=True)
312        self.utils_run_mock = patcher.start()
313        self.addCleanup(patcher.stop)
314
315        patcher2 = patch.object(urllib.request, 'urlopen', spec=True)
316        self.urlopen_mock = patcher2.start()
317        self.addCleanup(patcher2.stop)
318
319        patcher3 = patch.object(dev_server.ImageServerBase, 'run_call')
320        self.run_call_mock = patcher3.start()
321        self.addCleanup(patcher3.stop)
322
323        patcher4 = patch.object(os.path, 'exists', spec=True)
324        self.os_exists_mock = patcher4.start()
325        self.addCleanup(patcher4.stop)
326
327        # Hide local restricted_subnets setting.
328        dev_server.RESTRICTED_SUBNETS = []
329
330        _read_json_response_from_devserver = patch.object(
331                dev_server.ImageServer, '_read_json_response_from_devserver')
332        self._read_json_mock = _read_json_response_from_devserver.start()
333        self.addCleanup(_read_json_response_from_devserver.stop)
334
335        sleep = mock.patch('time.sleep', autospec=True)
336        sleep.start()
337        self.addCleanup(sleep.stop)
338
339        self.image_name = 'fake/image'
340        first_staged = comparators.Substrings(
341                [self._HOST, self.image_name, 'stage?'])
342        second_staged = comparators.Substrings(
343                [self._HOST, self.image_name, 'is_staged'])
344        self.staged_calls = [call(first_staged), call(second_staged)]
345
346    def _standard_assert_calls(self):
347        """Assert the standard calls are made."""
348        bad_host, good_host = 'http://bad_host:99', 'http://good_host:8080'
349
350        argument1 = comparators.Substring(bad_host)
351        argument2 = comparators.Substring(good_host)
352        calls = [
353                call(argument1, timeout=mock.ANY),
354                call(argument2, timeout=mock.ANY)
355        ]
356        self.run_call_mock.assert_has_calls(calls)
357
358    def testSimpleResolve(self):
359        """One devserver, verify we resolve to it."""
360        with patch.object(dev_server,
361             '_get_dev_server_list') as server_list_patch, \
362                patch.object(dev_server.ImageServer,
363                'devserver_healthy') as devserver_healthy_patch:
364
365            dev_server._get_dev_server_list.return_value = ([
366                    DevServerTest._HOST
367            ])
368
369            dev_server.ImageServer.devserver_healthy.return_value = True
370            devserver = dev_server.ImageServer.resolve('my_build')
371            self.assertEquals(devserver.url(), DevServerTest._HOST)
372
373            server_list_patch.assert_called_with()
374            devserver_healthy_patch.assert_called_with(DevServerTest._HOST)
375
376
377    def testResolveWithFailure(self):
378        """Ensure we rehash on a failed ping on a bad_host."""
379
380        with patch.object(dev_server, '_get_dev_server_list'):
381            bad_host, good_host = 'http://bad_host:99', 'http://good_host:8080'
382            dev_server._get_dev_server_list.return_value = ([
383                    bad_host, good_host
384            ])
385
386            # Mock out bad ping failure by raising devserver exception.
387            dev_server.ImageServerBase.run_call.side_effect = [
388                    dev_server.DevServerException(), '{"free_disk": 1024}'
389            ]
390
391            host = dev_server.ImageServer.resolve(
392                    0)  # Using 0 as it'll hash to 0.
393            self.assertEquals(host.url(), good_host)
394            self._standard_assert_calls()
395
396
397    def testResolveWithFailureURLError(self):
398        """Ensure we rehash on a failed ping using http on a bad_host after
399        urlerror."""
400        # Set retry.retry to retry_mock for just returning the original
401        # method for this test. This is to save waiting time for real retry,
402        # which is defined by dev_server.DEVSERVER_SSH_TIMEOUT_MINS.
403        # Will reset retry.retry to real retry at the end of this test.
404        real_retry = retry.retry
405        retry.retry = retry_mock
406
407        with patch.object(dev_server, '_get_dev_server_list'):
408
409            bad_host, good_host = 'http://bad_host:99', 'http://good_host:8080'
410            dev_server._get_dev_server_list.return_value = ([
411                    bad_host, good_host
412            ])
413
414            # Mock out bad ping failure by raising devserver exception.
415            dev_server.ImageServerBase.run_call.side_effect = [
416                    urllib.error.URLError('urlopen connection timeout'),
417                    '{"free_disk": 1024}'
418            ]
419
420            host = dev_server.ImageServer.resolve(
421                    0)  # Using 0 as it'll hash to 0.
422            self.assertEquals(host.url(), good_host)
423
424            retry.retry = real_retry
425            self._standard_assert_calls()
426
427
428    def testResolveWithManyDevservers(self):
429        """Should be able to return different urls with multiple devservers."""
430
431        with patch.object(dev_server.ImageServer, 'servers'), \
432                patch.object(dev_server.DevServer,
433                 'devserver_healthy') as devserver_healthy_patch:
434
435            host0_expected = 'http://host0:8080'
436            host1_expected = 'http://host1:8082'
437
438            dev_server.ImageServer.servers.return_value = ([
439                    host0_expected, host1_expected
440            ])
441            dev_server.ImageServer.devserver_healthy.return_value = True
442            dev_server.ImageServer.devserver_healthy.return_value = True
443
444            host0 = dev_server.ImageServer.resolve(0)
445            host1 = dev_server.ImageServer.resolve(1)
446
447            self.assertEqual(host0.url(), host0_expected)
448            self.assertEqual(host1.url(), host1_expected)
449
450            calls = [call(host0_expected), call(host1_expected)]
451            devserver_healthy_patch.assert_has_calls(calls)
452
453
454    def testSuccessfulTriggerDownloadSync(self):
455        """Call the dev server's download method with synchronous=True."""
456        with patch.object(dev_server.ImageServer,
457                          '_finish_download') as download_patch:
458
459            dev_server.ImageServerBase.run_call.side_effect = [
460                    'Success', 'True'
461            ]
462            self.dev_server._finish_download.return_value = None
463
464            # Synchronous case requires a call to finish download.
465            self.dev_server.trigger_download(self.image_name, synchronous=True)
466
467            download_patch.assert_called_with(self.image_name, mock.ANY,
468                                              mock.ANY)
469
470            self.run_call_mock.assert_has_calls(self.staged_calls)
471
472
473    def testSuccessfulTriggerDownloadASync(self):
474        """Call the dev server's download method with synchronous=False."""
475        dev_server.ImageServerBase.run_call.side_effect = ['Success', 'True']
476        self.dev_server.trigger_download(self.image_name, synchronous=False)
477
478        self.run_call_mock.assert_has_calls(self.staged_calls)
479
480    def testURLErrorRetryTriggerDownload(self):
481        """Should retry on URLError, but pass through real exception."""
482        with patch.object(time, 'sleep'):
483
484            refused = urllib.error.URLError('[Errno 111] Connection refused')
485            dev_server.ImageServerBase.run_call.side_effect = refused
486            time.sleep(mock.ANY)
487
488            dev_server.ImageServerBase.run_call.side_effect = E403
489            self.assertRaises(dev_server.DevServerException,
490                              self.dev_server.trigger_download, '')
491            self.run_call_mock.assert_called()
492
493
494    def testErrorTriggerDownload(self):
495        """Should call the dev server's download method using http, fail
496        gracefully."""
497        dev_server.ImageServerBase.run_call.side_effect = E500
498        self.assertRaises(dev_server.DevServerException,
499                          self.dev_server.trigger_download,
500                          '')
501        self.run_call_mock.assert_called()
502
503
504    def testForbiddenTriggerDownload(self):
505        """Should call the dev server's download method using http,
506        get exception."""
507        dev_server.ImageServerBase.run_call.side_effect = E500
508        self.assertRaises(dev_server.DevServerException,
509                          self.dev_server.trigger_download,
510                          '')
511        self.run_call_mock.assert_called()
512
513
514    def testCmdErrorTriggerDownload(self):
515        """Should call the dev server's download method using ssh, retry
516        trigger_download when getting error.CmdError, raise exception for
517        urllib2.HTTPError."""
518
519        dev_server.ImageServerBase.run_call.side_effect = [CMD_ERROR, E500]
520        self.assertRaises(dev_server.DevServerException,
521                          self.dev_server.trigger_download,
522                          '')
523        self.run_call_mock.assert_has_calls([call(mock.ANY), call(mock.ANY)])
524
525
526    def testSuccessfulFinishDownload(self):
527        """Should successfully call the dev server's finish download method."""
528        dev_server.ImageServerBase.run_call.side_effect = ['Success', 'True']
529
530        # Synchronous case requires a call to finish download.
531        self.dev_server.finish_download(self.image_name)  # Raises on failure.
532
533        self.run_call_mock.assert_has_calls(self.staged_calls)
534
535    def testErrorFinishDownload(self):
536        """Should call the dev server's finish download method using http, fail
537        gracefully."""
538        dev_server.ImageServerBase.run_call.side_effect = E500
539        self.assertRaises(dev_server.DevServerException,
540                          self.dev_server.finish_download,
541                          '')
542        self.run_call_mock.assert_called()
543
544    def testCmdErrorFinishDownload(self):
545        """Should call the dev server's finish download method using ssh,
546        retry finish_download when getting error.CmdError, raise exception
547        for urllib2.HTTPError."""
548        dev_server.ImageServerBase.run_call.side_effect = [CMD_ERROR, E500]
549
550        self.assertRaises(dev_server.DevServerException,
551                          self.dev_server.finish_download,
552                          '')
553        self.run_call_mock.assert_has_calls([call(mock.ANY), call(mock.ANY)])
554
555    def testListControlFiles(self):
556        """Should successfully list control files from the dev server."""
557        control_files = ['file/one', 'file/two']
558        argument = comparators.Substrings([self._HOST, self.image_name])
559        dev_server.ImageServerBase.run_call.return_value = control_files
560
561        paths = self.dev_server.list_control_files(self.image_name)
562        self.assertEquals(len(paths), 2)
563        for f in control_files:
564            self.assertTrue(f in paths)
565
566        self.run_call_mock.assert_called_with(argument, readline=True)
567
568    def testFailedListControlFiles(self):
569        """Should call the dev server's list-files method using http, get
570        exception."""
571        dev_server.ImageServerBase.run_call.side_effect = E500
572        self.assertRaises(dev_server.DevServerException,
573                          self.dev_server.list_control_files,
574                          '')
575        self.run_call_mock.assert_called_with(mock.ANY, readline=True)
576
577
578    def testExplodingListControlFiles(self):
579        """Should call the dev server's list-files method using http, get
580        exception."""
581        dev_server.ImageServerBase.run_call.side_effect = E403
582        self.assertRaises(dev_server.DevServerException,
583                          self.dev_server.list_control_files, '')
584        self.run_call_mock.assert_called_with(mock.ANY, readline=True)
585
586    def testCmdErrorListControlFiles(self):
587        """Should call the dev server's list-files method using ssh, retry
588        list_control_files when getting error.CmdError, raise exception for
589        urllib2.HTTPError."""
590        dev_server.ImageServerBase.run_call.side_effect = [CMD_ERROR, E500]
591        self.assertRaises(dev_server.DevServerException,
592                          self.dev_server.list_control_files,
593                          '')
594        self.run_call_mock.assert_called_with(mock.ANY, readline=True)
595
596    def testListSuiteControls(self):
597        """Should successfully list all contents of control files from the dev
598        server."""
599        control_contents = ['control file one', 'control file two']
600        argument = comparators.Substrings([self._HOST, self.image_name])
601
602        dev_server.ImageServerBase.run_call.return_value = (
603                json.dumps(control_contents))
604
605        file_contents = self.dev_server.list_suite_controls(self.image_name)
606        self.assertEquals(len(file_contents), 2)
607        for f in control_contents:
608            self.assertTrue(f in file_contents)
609
610        self.run_call_mock.assert_called_with(argument)
611
612    def testFailedListSuiteControls(self):
613        """Should call the dev server's list_suite_controls method using http,
614        get exception."""
615        dev_server.ImageServerBase.run_call.side_effect = E500
616
617        self.assertRaises(dev_server.DevServerException,
618                          self.dev_server.list_suite_controls,
619                          '')
620        self.run_call_mock.assert_called()
621
622
623    def testExplodingListSuiteControls(self):
624        """Should call the dev server's list_suite_controls method using http,
625        get exception."""
626        dev_server.ImageServerBase.run_call.side_effect = E403
627
628        self.assertRaises(dev_server.DevServerException,
629                          self.dev_server.list_suite_controls,
630                          '')
631        self.run_call_mock.assert_called()
632
633    def testCmdErrorListSuiteControls(self):
634        """Should call the dev server's list_suite_controls method using ssh,
635        retry list_suite_controls when getting error.CmdError, raise exception
636        for urllib2.HTTPError."""
637        dev_server.ImageServerBase.run_call.side_effect = [CMD_ERROR, E500]
638
639        self.assertRaises(dev_server.DevServerException,
640                          self.dev_server.list_suite_controls,
641                          '')
642        self.run_call_mock.assert_has_calls([call(mock.ANY), call(mock.ANY)])
643
644    def testGetControlFile(self):
645        """Should successfully get a control file from the dev server."""
646        file = 'file/one'
647        contents = 'Multi-line\nControl File Contents\n'
648        argument = comparators.Substrings([self._HOST, self.image_name, file])
649
650        dev_server.ImageServerBase.run_call.return_value = contents
651
652        self.assertEquals(
653                self.dev_server.get_control_file(self.image_name, file),
654                contents)
655
656        self.run_call_mock.assert_called_with(argument)
657
658    def testErrorGetControlFile(self):
659        """Should try to get the contents of a control file using http, get
660        exception."""
661        dev_server.ImageServerBase.run_call.side_effect = E500
662        self.assertRaises(dev_server.DevServerException,
663                          self.dev_server.get_control_file,
664                          '', '')
665        self.run_call_mock.assert_called()
666
667    def testForbiddenGetControlFile(self):
668        """Should try to get the contents of a control file using http, get
669        exception."""
670        dev_server.ImageServerBase.run_call.side_effect = E403
671        self.assertRaises(dev_server.DevServerException,
672                          self.dev_server.get_control_file,
673                          '', '')
674        self.run_call_mock.assert_called()
675
676
677    def testCmdErrorGetControlFile(self):
678        """Should try to get the contents of a control file using ssh, retry
679        get_control_file when getting error.CmdError, raise exception for
680        urllib2.HTTPError."""
681        dev_server.ImageServerBase.run_call.side_effect = [CMD_ERROR, E500]
682
683        self.assertRaises(dev_server.DevServerException,
684                          self.dev_server.get_control_file, '', '')
685        self.run_call_mock.assert_has_calls([call(mock.ANY), call(mock.ANY)])
686
687
688    def testGetLatestBuild(self):
689        """Should successfully return a build for a given target."""
690        with patch.object(dev_server.ImageServer, 'servers'), \
691            patch.object(dev_server.ImageServer,
692                         'devserver_healthy') as devserver_patch:
693
694            dev_server.ImageServer.servers.return_value = [self._HOST]
695            dev_server.ImageServer.devserver_healthy.return_value = True
696
697            target = 'x86-generic-release'
698            build_string = 'R18-1586.0.0-a1-b1514'
699            argument = comparators.Substrings([self._HOST, target])
700
701            dev_server.ImageServerBase.run_call.return_value = build_string
702
703            build = dev_server.ImageServer.get_latest_build(target)
704            self.assertEquals(build_string, build)
705
706            devserver_patch.assert_called_with(self._HOST)
707            self.run_call_mock.assert_called_with(argument)
708
709
710    def testGetLatestBuildWithManyDevservers(self):
711        """Should successfully return newest build with multiple devservers."""
712        with patch.object(dev_server.ImageServer, 'servers'), \
713            patch.object(dev_server.ImageServer,
714                         'devserver_healthy') as devserver_patch:
715
716            host0_expected = 'http://host0:8080'
717            host1_expected = 'http://host1:8082'
718
719            dev_server.ImageServer.servers.return_value = ([
720                    host0_expected, host1_expected
721            ])
722
723            dev_server.ImageServer.devserver_healthy.return_value = True
724
725            dev_server.ImageServer.devserver_healthy.return_value = True
726
727            target = 'x86-generic-release'
728            build_string1 = 'R9-1586.0.0-a1-b1514'
729            build_string2 = 'R19-1586.0.0-a1-b3514'
730            argument1 = comparators.Substrings([host0_expected, target])
731            argument2 = comparators.Substrings([host1_expected, target])
732
733            dev_server.ImageServerBase.run_call.side_effect = ([
734                    build_string1, build_string2
735            ])
736
737            build = dev_server.ImageServer.get_latest_build(target)
738            self.assertEquals(build_string2, build)
739            devserver_patch.assert_has_calls(
740                    [call(host0_expected),
741                     call(host1_expected)])
742
743            self.run_call_mock.assert_has_calls(
744                    [call(argument1), call(argument2)])
745
746
747    def testCrashesAreSetToTheCrashServer(self):
748        """Should send symbolicate dump rpc calls to crash_server."""
749        call = self.crash_server.build_call('symbolicate_dump')
750        self.assertTrue(call.startswith(self._CRASH_HOST))
751
752
753    def _stageTestHelper(self, artifacts=[], files=[], archive_url=None):
754        """Helper to test combos of files/artifacts/urls with stage call."""
755        expected_archive_url = archive_url
756        if not archive_url:
757            expected_archive_url = 'gs://my_default_url'
758            image_patch = patch.object(dev_server, '_get_image_storage_server')
759            self.image_server_mock = image_patch.start()
760            self.addCleanup(image_patch.stop)
761            dev_server._get_image_storage_server.return_value = (
762                    'gs://my_default_url')
763            name = 'fake/image'
764        else:
765            # This is embedded in the archive_url. Not needed.
766            name = ''
767
768        argument1 = comparators.Substrings([
769                expected_archive_url, name,
770                'artifacts=%s' % ','.join(artifacts),
771                'files=%s' % ','.join(files), 'stage?'
772        ])
773        argument2 = comparators.Substrings([
774                expected_archive_url, name,
775                'artifacts=%s' % ','.join(artifacts),
776                'files=%s' % ','.join(files), 'is_staged?'
777        ])
778
779        dev_server.ImageServerBase.run_call.side_effect = ['Success', 'True']
780
781        self.dev_server.stage_artifacts(name, artifacts, files, archive_url)
782        self.run_call_mock.assert_has_calls([call(argument1), call(argument2)])
783
784
785    def testStageArtifactsBasic(self):
786        """Basic functionality to stage artifacts (similar to
787        trigger_download)."""
788        self._stageTestHelper(artifacts=['full_payload', 'stateful'])
789
790
791    def testStageArtifactsBasicWithFiles(self):
792        """Basic functionality to stage artifacts (similar to
793        trigger_download)."""
794        self._stageTestHelper(artifacts=['full_payload', 'stateful'],
795                              files=['taco_bell.coupon'])
796
797
798    def testStageArtifactsOnlyFiles(self):
799        """Test staging of only file artifacts."""
800        self._stageTestHelper(files=['tasty_taco_bell.coupon'])
801
802
803    def testStageWithArchiveURL(self):
804        """Basic functionality to stage artifacts (similar to
805        trigger_download)."""
806        self._stageTestHelper(files=['tasty_taco_bell.coupon'],
807                              archive_url='gs://tacos_galore/my/dir')
808
809
810    def testStagedFileUrl(self):
811        """Tests that the staged file url looks right."""
812        devserver_label = 'x86-mario-release/R30-1234.0.0'
813        url = self.dev_server.get_staged_file_url('stateful.tgz',
814                                                  devserver_label)
815        expected_url = '/'.join([self._HOST, 'static', devserver_label,
816                                 'stateful.tgz'])
817        self.assertEquals(url, expected_url)
818
819        devserver_label = 'something_complex/that/you_MIGHT/hate'
820        url = self.dev_server.get_staged_file_url('chromiumos_image.bin',
821                                                  devserver_label)
822        expected_url = '/'.join([self._HOST, 'static', devserver_label,
823                                 'chromiumos_image.bin'])
824        self.assertEquals(url, expected_url)
825
826
827    def _StageTimeoutHelper(self):
828        """Helper class for testing staging timeout."""
829        call_patch = patch.object(dev_server.ImageServer, 'call_and_wait')
830        self.call_mock = call_patch.start()
831        self.addCleanup(call_patch.stop)
832        dev_server.ImageServer.call_and_wait.side_effect = (
833                bin_utils.TimeoutError())
834
835    def _VerifyTimeoutHelper(self):
836        self.call_mock.assert_called_with(call_name='stage',
837                                          artifacts=mock.ANY,
838                                          files=mock.ANY,
839                                          archive_url=mock.ANY,
840                                          error_message=mock.ANY)
841
842
843    def test_StageArtifactsTimeout(self):
844        """Test DevServerException is raised when stage_artifacts timed out."""
845        self._StageTimeoutHelper()
846
847        self.assertRaises(dev_server.DevServerException,
848                          self.dev_server.stage_artifacts,
849                          image='fake/image', artifacts=['full_payload'])
850        self._VerifyTimeoutHelper()
851
852
853    def test_TriggerDownloadTimeout(self):
854        """Test DevServerException is raised when trigger_download timed out."""
855        self._StageTimeoutHelper()
856        self.assertRaises(dev_server.DevServerException,
857                          self.dev_server.trigger_download,
858                          image='fake/image')
859        self._VerifyTimeoutHelper()
860
861    def test_FinishDownloadTimeout(self):
862        """Test DevServerException is raised when finish_download timed out."""
863        self._StageTimeoutHelper()
864        self.assertRaises(dev_server.DevServerException,
865                          self.dev_server.finish_download,
866                          image='fake/image')
867        self._VerifyTimeoutHelper()
868
869
870    def test_compare_load(self):
871        """Test load comparison logic.
872        """
873        load_high_cpu = {'devserver': 'http://devserver_1:8082',
874                         dev_server.DevServer.CPU_LOAD: 100.0,
875                         dev_server.DevServer.NETWORK_IO: 1024*1024*1.0,
876                         dev_server.DevServer.DISK_IO: 1024*1024.0}
877        load_high_network = {'devserver': 'http://devserver_1:8082',
878                             dev_server.DevServer.CPU_LOAD: 1.0,
879                             dev_server.DevServer.NETWORK_IO: 1024*1024*100.0,
880                             dev_server.DevServer.DISK_IO: 1024*1024*1.0}
881        load_1 = {'devserver': 'http://devserver_1:8082',
882                  dev_server.DevServer.CPU_LOAD: 1.0,
883                  dev_server.DevServer.NETWORK_IO: 1024*1024*1.0,
884                  dev_server.DevServer.DISK_IO: 1024*1024*2.0}
885        load_2 = {'devserver': 'http://devserver_1:8082',
886                  dev_server.DevServer.CPU_LOAD: 1.0,
887                  dev_server.DevServer.NETWORK_IO: 1024*1024*1.0,
888                  dev_server.DevServer.DISK_IO: 1024*1024*1.0}
889        self.assertFalse(dev_server._is_load_healthy(load_high_cpu))
890        self.assertFalse(dev_server._is_load_healthy(load_high_network))
891        self.assertTrue(dev_server._compare_load(load_1, load_2) > 0)
892
893
894    def _testSuccessfulTriggerDownloadAndroid(self, synchronous=True):
895        """Call the dev server's download method with given synchronous
896        setting.
897
898        @param synchronous: True to call the download method synchronously.
899        """
900        target = 'test_target'
901        branch = 'test_branch'
902        build_id = '123456'
903        artifacts = android_utils.AndroidArtifacts.get_artifacts_for_reimage(
904                None)
905        with patch.object(dev_server.AndroidBuildServer, '_finish_download'):
906
907            argument1 = comparators.Substrings(
908                    [self._HOST, target, branch, build_id, 'stage?'])
909            argument2 = comparators.Substrings(
910                    [self._HOST, target, branch, build_id, 'is_staged?'])
911
912            dev_server.ImageServerBase.run_call.side_effect = [
913                    'Success', 'True'
914            ]
915
916            if synchronous:
917                android_build_info = {
918                        'target': target,
919                        'build_id': build_id,
920                        'branch': branch
921                }
922                build = (dev_server.ANDROID_BUILD_NAME_PATTERN %
923                         android_build_info)
924                self.android_dev_server._finish_download(build,
925                                                         artifacts,
926                                                         '',
927                                                         target=target,
928                                                         build_id=build_id,
929                                                         branch=branch)
930
931            # Synchronous case requires a call to finish download.
932            self.android_dev_server.trigger_download(synchronous=synchronous,
933                                                     target=target,
934                                                     build_id=build_id,
935                                                     branch=branch)
936            self.run_call_mock.assert_has_calls(
937                    [call(argument1), call(argument2)])
938
939
940    def testSuccessfulTriggerDownloadAndroidSync(self):
941        """Call the dev server's download method with synchronous=True."""
942        self._testSuccessfulTriggerDownloadAndroid(synchronous=True)
943
944
945    def testSuccessfulTriggerDownloadAndroidAsync(self):
946        """Call the dev server's download method with synchronous=False."""
947        self._testSuccessfulTriggerDownloadAndroid(synchronous=False)
948
949
950    @unittest.expectedFailure
951    def testGetUnrestrictedDevservers(self):
952        """Test method get_unrestricted_devservers works as expected."""
953        restricted_devserver = 'http://192.168.0.100:8080'
954        unrestricted_devserver = 'http://172.1.1.3:8080'
955        with patch.object(dev_server.ImageServer, 'servers') as servers_patch:
956            dev_server.ImageServer.servers.return_value = ([
957                    restricted_devserver, unrestricted_devserver
958            ])
959            # crbug.com/1027277: get_unrestricted_devservers() now returns all
960            # servers.
961            self.assertEqual(
962                    dev_server.ImageServer.get_unrestricted_devservers([
963                            ('192.168.0.0', 24)
964                    ]), [unrestricted_devserver])
965
966            servers_patch.assert_called_once()
967
968    def testGetUnrestrictedDevserversReturnsAll(self):
969        """Test method get_unrestricted_devservers works as expected."""
970        restricted_devserver = 'http://192.168.0.100:8080'
971        unrestricted_devserver = 'http://172.1.1.3:8080'
972        with patch.object(dev_server.ImageServer, 'servers') as servers_patch:
973            dev_server.ImageServer.servers.return_value = ([
974                    restricted_devserver, unrestricted_devserver
975            ])
976            # crbug.com/1027277: get_unrestricted_devservers() now returns all
977            # servers.
978            self.assertEqual(
979                    dev_server.ImageServer.get_unrestricted_devservers([
980                            ('192.168.0.0', 24)
981                    ]), [restricted_devserver, unrestricted_devserver])
982
983            servers_patch.assert_called_once()
984
985    def testDevserverHealthy(self):
986        """Test which types of connections that method devserver_healthy uses
987        for different types of DevServer.
988
989        CrashServer always adopts DevServer.run_call.
990        ImageServer and AndroidBuildServer use ImageServerBase.run_call.
991        """
992        argument = comparators.Substring(self._HOST)
993
994        # for testing CrashServer
995
996        with patch.object(dev_server.DevServer, 'run_call'):
997            # for testing CrashServer
998            dev_server.DevServer.run_call.return_value = '{"free_disk": 1024}'
999
1000            # for testing ImageServer
1001            dev_server.ImageServer.run_call.return_value = (
1002                    '{"free_disk": 1024}')
1003
1004            # for testing AndroidBuildServer
1005            dev_server.AndroidBuildServer.run_call.return_value = (
1006                    '{"free_disk": 1024}')
1007
1008            self.assertTrue(
1009                    dev_server.CrashServer.devserver_healthy(self._HOST))
1010            self.assertTrue(
1011                    dev_server.ImageServer.devserver_healthy(self._HOST))
1012            self.assertTrue(
1013                    dev_server.AndroidBuildServer.devserver_healthy(
1014                            self._HOST))
1015
1016            dev_server.DevServer.run_call.assert_called_with(argument,
1017                                                             timeout=mock.ANY)
1018            dev_server.ImageServer.run_call.assert_called_with(
1019                    argument, timeout=mock.ANY)
1020            dev_server.AndroidBuildServer.run_call.assert_called_with(
1021                    argument, timeout=mock.ANY)
1022
1023
1024    def testLocateFile(self):
1025        """Test locating files for AndriodBuildServer."""
1026        file_name = 'fake_file'
1027        artifacts = ['full_payload', 'stateful']
1028        build = 'fake_build'
1029
1030        argument = comparators.Substrings([file_name, build, 'locate_file'])
1031        dev_server.ImageServerBase.run_call.return_value = 'file_path'
1032
1033        file_location = 'http://nothing/static/fake_build/file_path'
1034        self.assertEqual(self.android_dev_server.locate_file(
1035                file_name, artifacts, build, None), file_location)
1036        self.run_call_mock.assert_called_with(argument)
1037
1038    def testCmdErrorLocateFile(self):
1039        """Test locating files for AndriodBuildServer for retry
1040        error.CmdError, and raise urllib2.URLError."""
1041        dev_server.ImageServerBase.run_call.side_effect = CMD_ERROR
1042        dev_server.ImageServerBase.run_call.side_effect = E500
1043
1044        self.assertRaises(dev_server.DevServerException,
1045                          self.dev_server.trigger_download,
1046                          '')
1047
1048
1049    def testGetAvailableDevserversForCrashServer(self):
1050        """Test method get_available_devservers for CrashServer."""
1051        crash_servers = ['http://crash_servers1:8080']
1052        host = '127.0.0.1'
1053        with patch.object(dev_server.CrashServer, 'servers'):
1054            dev_server.CrashServer.servers.return_value = crash_servers
1055            self.assertEqual(
1056                    dev_server.CrashServer.get_available_devservers(host),
1057                    (crash_servers, False))
1058
1059
1060    def testGetAvailableDevserversForImageServer(self):
1061        """Test method get_available_devservers for ImageServer."""
1062        unrestricted_host = '100.0.0.99'
1063        unrestricted_servers = ['http://100.0.0.10:8080',
1064                                'http://128.0.0.10:8080']
1065        same_subnet_unrestricted_servers = ['http://100.0.0.10:8080']
1066        restricted_host = '127.0.0.99'
1067        restricted_servers = ['http://127.0.0.10:8080']
1068        all_servers = unrestricted_servers + restricted_servers
1069        # Set restricted subnets
1070        restricted_subnets = [('127.0.0.0', 24)]
1071
1072        with patch.object(dev_server.ImageServerBase, 'servers'):
1073            dev_server.ImageServerBase.servers.return_value = (all_servers)
1074
1075            # dut in unrestricted subnet shall be offered devserver in the same
1076            # subnet first, and allow retry.
1077            self.assertEqual(
1078                    dev_server.ImageServer.get_available_devservers(
1079                            unrestricted_host, True, restricted_subnets),
1080                    (same_subnet_unrestricted_servers, True))
1081
1082            # crbug.com/1027277: If prefer_local_devserver is set to False,
1083            # allow any devserver, and retry is not allowed.
1084            self.assertEqual(
1085                    dev_server.ImageServer.get_available_devservers(
1086                            unrestricted_host, False, restricted_subnets),
1087                    (all_servers, False))
1088
1089            # crbug.com/1027277: When no hostname is specified, all devservers
1090            # should be considered, and retry is not allowed.
1091            self.assertEqual(
1092                    dev_server.ImageServer.get_available_devservers(
1093                            None, True, restricted_subnets),
1094                    (all_servers, False))
1095
1096            # dut in restricted subnet should only be offered devserver in the
1097            # same restricted subnet, and retry is not allowed.
1098            self.assertEqual(
1099                    dev_server.ImageServer.get_available_devservers(
1100                            restricted_host, True, restricted_subnets),
1101                    (restricted_servers, False))
1102
1103
# Run the unit tests when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
1106