#!/usr/bin/python2
# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import unittest

import common
from autotest_lib.client.common_lib import error
from autotest_lib.server.hosts import base_label_unittest, factory
from autotest_lib.server.hosts import host_info


class MockHost(object):
    """Mock host object with no side effects."""
    def __init__(self, hostname, **args):
        # Keep every constructor argument so tests can inspect exactly what
        # the factory passed through.
        self._init_args = args
        self._init_args['hostname'] = hostname


    def job_start(self):
        """Only method called by factory."""
        pass


class MockConnectivity(object):
    """Mock connectivity object with no side effects."""
    def __init__(self, hostname, **args):
        pass


    def run(self, *args, **kwargs):
        """Accept any arguments and do nothing."""
        pass


    def close(self):
        """Do nothing."""
        pass


def _gen_mock_host(name, check_host=False):
    """Create an identifiable mock host class.

    @param name: identifier stored on the class as `_host_cls_name` and used
            in the generated class name.
    @param check_host: value returned by the generated class's static
            `check_host` method, whatever arguments it is called with.

    @return: a new subclass of MockHost.
    """
    return type('mock_host_%s' % name, (MockHost,), {
        '_host_cls_name': name,
        'check_host': staticmethod(lambda host, timeout=None: check_host)
    })


def _gen_mock_conn(name):
    """Create an identifiable mock connectivity class.

    @param name: identifier stored on the class as `_conn_cls_name` and used
            in the generated class name.

    @return: a new subclass of MockConnectivity.
    """
    return type('mock_conn_%s' % name, (MockConnectivity,),
                {'_conn_cls_name': name})


def _gen_machine_dict(hostname='localhost', labels=None, attributes=None):
    """Generate a machine dictionary with the specified parameters.

    @param hostname: hostname of machine
    @param labels: list of host labels (defaults to a fresh empty list)
    @param attributes: dict of host attributes (defaults to a fresh empty
            dict)

    @return: machine dict with a mocked AFE Host object and an in-memory
            host info store pre-populated with the same labels/attributes.
    """
    # Use None defaults so every call gets its own containers; shared mutable
    # defaults would leak state between tests if anything mutated them.
    if labels is None:
        labels = []
    if attributes is None:
        attributes = {}
    afe_host = base_label_unittest.MockAFEHost(labels, attributes)
    store = host_info.InMemoryHostInfoStore()
    store.commit(host_info.HostInfo(labels, attributes))
    return {'hostname': hostname,
            'afe_host': afe_host,
            'host_info_store': store}


class CreateHostUnittests(unittest.TestCase):
    """Tests for create_host function."""

    def setUp(self):
        """Prevent use of real Host and connectivity objects due to potential
        side effects.
        """
        self._orig_types = factory.host_types
        self._orig_dict = factory.OS_HOST_DICT
        self._orig_cros_host = factory.cros_host.CrosHost
        self._orig_local_host = factory.local_host.LocalHost
        self._orig_ssh_host = factory.ssh_host.SSHHost

        self.host_types = factory.host_types = []
        self.os_host_dict = factory.OS_HOST_DICT = {}
        factory.cros_host.CrosHost = _gen_mock_host('cros_host')
        factory.local_host.LocalHost = _gen_mock_conn('local')
        factory.ssh_host.SSHHost = _gen_mock_conn('ssh')


    def tearDown(self):
        """Clean up mocks."""
        factory.host_types = self._orig_types
        factory.OS_HOST_DICT = self._orig_dict
        factory.cros_host.CrosHost = self._orig_cros_host
        factory.local_host.LocalHost = self._orig_local_host
        factory.ssh_host.SSHHost = self._orig_ssh_host


    def test_use_specified(self):
        """Confirm that the specified host class is used."""
        machine = _gen_machine_dict()
        host_obj = factory.create_host(
                machine,
                _gen_mock_host('specified'),
        )
        self.assertEqual(host_obj._host_cls_name, 'specified')


    def test_detect_host_by_os_label(self):
        """Confirm that the host object is selected by the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'])
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'foo')


    def test_detect_host_by_os_type_attribute(self):
        """Confirm that the host object is selected by the os_type attribute
        and that the os_type attribute is preferred over the os label.
        """
        machine = _gen_machine_dict(labels=['os:foo'],
                                    attributes={'os_type': 'bar'})
        self.os_host_dict['foo'] = _gen_mock_host('foo')
        self.os_host_dict['bar'] = _gen_mock_host('bar')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'bar')


    def test_detect_host_by_check_host(self):
        """Confirm check_host logic chooses a host object when label/attribute
        detection fails.
        """
        machine = _gen_machine_dict()
        self.host_types.append(_gen_mock_host('first', check_host=False))
        self.host_types.append(_gen_mock_host('second', check_host=True))
        self.host_types.append(_gen_mock_host('third', check_host=False))
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'second')


    def test_detect_host_fallback_to_cros_host(self):
        """Confirm fallback to CrosHost when all other detection fails.
        """
        machine = _gen_machine_dict()
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._host_cls_name, 'cros_host')


    def test_choose_connectivity_local(self):
        """Confirm local connectivity class used when hostname is localhost.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'local')


    def test_choose_connectivity_ssh(self):
        """Confirm ssh connectivity class used when configured and hostname
        is not localhost.
        """
        machine = _gen_machine_dict(hostname='somehost')
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._conn_cls_name, 'ssh')


    def test_argument_passthrough(self):
        """Confirm that detected and specified arguments are passed through to
        the host object.
        """
        machine = _gen_machine_dict(hostname='localhost')
        host_obj = factory.create_host(machine, foo='bar')
        self.assertEqual(host_obj._init_args['hostname'], 'localhost')
        self.assertIn('afe_host', host_obj._init_args)
        self.assertIn('host_info_store', host_obj._init_args)
        self.assertEqual(host_obj._init_args['foo'], 'bar')


    def test_global_ssh_params(self):
        """Confirm passing of ssh parameters set as globals.
        """
        overrides = {
            'ssh_user': 'foo',
            'ssh_pass': 'bar',
            'ssh_port': 1,
            'ssh_verbosity_flag': 'baz',
            'ssh_options': 'zip',
        }
        # Remember any pre-existing module globals so they are restored
        # afterwards rather than unconditionally deleted.
        missing = object()
        saved = {name: getattr(factory, name, missing) for name in overrides}
        try:
            for name, value in overrides.items():
                setattr(factory, name, value)
            machine = _gen_machine_dict()
            host_obj = factory.create_host(machine)
            self.assertEqual(host_obj._init_args['user'], 'foo')
            self.assertEqual(host_obj._init_args['password'], 'bar')
            self.assertEqual(host_obj._init_args['port'], 1)
            self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'baz')
            self.assertEqual(host_obj._init_args['ssh_options'], 'zip')
        finally:
            for name, value in saved.items():
                if value is not missing:
                    setattr(factory, name, value)
                elif hasattr(factory, name):
                    delattr(factory, name)


    def test_host_attribute_ssh_params(self):
        """Confirm passing of ssh parameters from host attributes.
        """
        machine = _gen_machine_dict(attributes={'ssh_user': 'somebody',
                                                'ssh_port': 100,
                                                'ssh_verbosity_flag': 'verb',
                                                'ssh_options': 'options'})
        host_obj = factory.create_host(machine)
        self.assertEqual(host_obj._init_args['user'], 'somebody')
        self.assertEqual(host_obj._init_args['port'], 100)
        self.assertEqual(host_obj._init_args['ssh_verbosity_flag'], 'verb')
        self.assertEqual(host_obj._init_args['ssh_options'], 'options')


if __name__ == '__main__':
    unittest.main()