• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/python
2# Copyright 2016 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import mock
7import unittest
8
9import common
10from autotest_lib.client.common_lib import error
11from autotest_lib.server.hosts import base_label_unittest, factory
12from autotest_lib.server.hosts import host_info
13
14
class MockHost(object):
    """Side-effect-free stand-in for a real Host class.

    The constructor keeps every keyword argument it receives, plus the
    hostname, in `_init_args` so tests can inspect what was passed in.
    """

    def __init__(self, hostname, **args):
        args['hostname'] = hostname
        self._init_args = args


    def job_start(self):
        """Only method called by the factory; intentionally a no-op."""
25
26
class MockConnectivity(object):
    """Side-effect-free stand-in for a connectivity class.

    In these tests it replaces LocalHost and SSHHost. It accepts and
    discards any constructor arguments.
    """

    def __init__(self, hostname, **args):
        """Ignore all constructor arguments."""


    def close(self):
        """Only method called by the factory; intentionally a no-op."""
36
37
def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: stored on the class as `_host_cls_name` so tests can tell
                 which host class the factory selected.
    @param check_host: constant value returned by the generated static
                       `check_host(host, timeout=None)` method.

    @return: a new class named 'mock_host_<name>' derived from MockHost.
    """
    def _constant_check_host(host, timeout=None):
        return check_host

    class_attrs = {
        '_host_cls_name': name,
        'check_host': staticmethod(_constant_check_host),
    }
    return type('mock_host_%s' % name, (MockHost,), class_attrs)
45
46
def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: stored on the class as `_conn_cls_name` so tests can tell
                 which connectivity class the factory selected.

    @return: a new class named 'mock_conn_<name>' derived from
             MockConnectivity.
    """
    class_name = 'mock_conn_%s' % name
    bases = (MockConnectivity,)
    return type(class_name, bases, dict(_conn_cls_name=name))
52
53
def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels; a fresh empty list when omitted.
    @param attributes: dict of host attributes; a fresh empty dict when
                       omitted.

    @return: machine dict with a mocked AFE Host object and an
             InMemoryHostInfoStore pre-populated with the same labels and
             attributes.
    """
    # None sentinels instead of mutable defaults: the same list/dict objects
    # are handed to the AFE mock and the HostInfo, so a shared default could
    # carry mutations from one test into the next.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    store = host_info.InMemoryHostInfoStore()
    store.commit(host_info.HostInfo(labels, attributes))
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': store}
69
70
class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.
        """
        # Remember every factory global/class overridden below so tearDown
        # can restore them.
        self._orig_ssh_engine = factory.SSH_ENGINE
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        # Start each test with empty detection tables; individual tests
        # populate them through self.host_types / self.os_host_dict.
        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks."""
        factory.SSH_ENGINE = self._orig_ssh_engine
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host and connectivity classes are used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
                _gen_mock_conn('specified')
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')
        self.assertEqual(host_obj._conn_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        factory.SSH_ENGINE = 'raw_ssh'
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_choose_connectivity_unsupported(self):
        """Confirm exception when configured for unsupported ssh engine.
        """
        factory.SSH_ENGINE = 'unsupported'
        machine = _gen_machine_dict(hostname='somehost')
        with self.assertRaises(error.AutoservError):
            factory.create_host(machine)


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        machine = _gen_machine_dict()
        # patch.multiple(create=True) installs the module globals only for
        # the duration of the block. On exit it restores any value that
        # existed before, or deletes the ones it created. The old manual
        # `del` in a finally clause destroyed pre-existing values.
        with mock.patch.multiple(factory, create=True,
                                 ssh_user='foo',
                                 ssh_pass='bar',
                                 ssh_port=1,
                                 ssh_verbosity_flag='baz',
                                 ssh_options='zip'):
            host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'foo')
        self.assertEqual(host_obj._init_args['password'], 'bar')
        self.assertEqual(host_obj._init_args['port'], 1)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
        self.assertEqual(host_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')
230
231
class CreateTestbedUnittests(unittest.TestCase):
    """Tests for create_testbed function."""

    def setUp(self):
        """Mock out TestBed class to eliminate side effects.
        """
        self._orig_testbed = factory.testbed.TestBed
        factory.testbed.TestBed = _gen_mock_host('testbed')


    def tearDown(self):
        """Clean up mock.
        """
        factory.testbed.TestBed = self._orig_testbed


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the testbed object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        testbed_obj = factory.create_testbed(machine, foo='bar')
        self.assertEqual(testbed_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', testbed_obj._init_args)
        self.assertIn('host_info_store', testbed_obj._init_args)
        self.assertEqual(testbed_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        machine = _gen_machine_dict()
        # patch.multiple(create=True) installs the module globals only for
        # the duration of the block. On exit it restores any value that
        # existed before, or deletes the ones it created, so no module state
        # leaks into other tests.
        with mock.patch.multiple(factory, create=True,
                                 ssh_user='foo',
                                 ssh_pass='bar',
                                 ssh_port=1,
                                 ssh_verbosity_flag='baz',
                                 ssh_options='zip'):
            testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'foo')
        self.assertEqual(testbed_obj._init_args['password'], 'bar')
        self.assertEqual(testbed_obj._init_args['port'], 1)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'],
                         'baz')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'zip')


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        testbed_obj = factory.create_testbed(machine)
        self.assertEqual(testbed_obj._init_args['user'], 'somebody')
        self.assertEqual(testbed_obj._init_args['port'], 100)
        self.assertEqual(testbed_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(testbed_obj._init_args['ssh_options'], 'options')
297
298
def _main():
    """Run the test cases in this module via the unittest runner."""
    unittest.main()


if __name__ == '__main__':
    _main()
301
302