# Copyright 2016 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Unit tests for autotest_lib.server.hosts.host_info."""

import cStringIO
import inspect
import json
import unittest

# Imported only for its side effects; it is not referenced below.
# NOTE(review): presumably it sets up the autotest_lib import path — confirm.
import common
from autotest_lib.server.hosts import host_info


class HostInfoTest(unittest.TestCase):
    """Tests the non-trivial attributes of HostInfo."""

    def setUp(self):
        self.info = host_info.HostInfo()


    def test_info_comparison_to_wrong_type(self):
        """Comparing HostInfo to a different type always returns False."""
        self.assertNotEqual(host_info.HostInfo(), 42)
        self.assertNotEqual(host_info.HostInfo(), None)
        # Equality and non-equality are unrelated by the (Python 2) data
        # model: __ne__ is not derived from __eq__, so check both explicitly.
        self.assertFalse(host_info.HostInfo() == 42)
        self.assertFalse(host_info.HostInfo() == None)


    def test_empty_infos_are_equal(self):
        """Tests that empty HostInfo objects are considered equal."""
        self.assertEqual(host_info.HostInfo(), host_info.HostInfo())
        # Equality and non-equality are unrelated by the data model.
        self.assertFalse(host_info.HostInfo() != host_info.HostInfo())


    def test_non_trivial_infos_are_equal(self):
        """Tests that the most complicated infos are correctly stated equal."""
        info1 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'},
                stable_versions={
                        "cros": "xxx-cros",
                        "faft": "xxx-faft",
                        "firmware": "xxx-firmware",
                },
        )
        info2 = host_info.HostInfo(
                labels=['label1', 'label2', 'label1'],
                attributes={'attrib1': None, 'attrib2': 'val2'},
                stable_versions={
                        "cros": "xxx-cros",
                        "faft": "xxx-faft",
                        "firmware": "xxx-firmware",
                },
        )
        self.assertEqual(info1, info2)
        # Equality and non-equality are unrelated by the data model.
        self.assertFalse(info1 != info2)


    def test_non_equal_infos(self):
        """Tests that HostInfo objects with different information are unequal"""
        info1 = host_info.HostInfo(labels=['label'])
        info2 = host_info.HostInfo(attributes={'attrib': 'value'})
        self.assertNotEqual(info1, info2)
        # Equality and non-equality are unrelated by the data model.
        self.assertFalse(info1 == info2)


    def test_build_needs_prefix(self):
        """The build prefix is of the form '<type>-version:'"""
        self.info.labels = ['cros-version', 'fwrw-version', 'fwro-version']
        self.assertIsNone(self.info.build)


    def test_build_prefix_must_be_anchored(self):
        """Ensure that build ignores prefixes occurring mid-string."""
        self.info.labels = ['not-at-start-cros-version:cros1']
        self.assertIsNone(self.info.build)


    def test_build_ignores_firmware(self):
        """build attribute should ignore firmware versions."""
        self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
        self.assertIsNone(self.info.build)


    def test_build_returns_first_match(self):
        """When multiple labels match, first one should be used as build."""
        self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
        self.assertEqual(self.info.build, 'cros1')


    def test_build_prefer_cros_over_others(self):
        """When multiple versions are available, prefer cros."""
        self.info.labels = ['cheets-version:ab1', 'cros-version:cros1']
        self.assertEqual(self.info.build, 'cros1')
        self.info.labels = ['cros-version:cros1', 'cheets-version:ab1']
        self.assertEqual(self.info.build, 'cros1')


    def test_os_no_match(self):
        """Use proper prefix to search for os information."""
        self.info.labels = ['something_else', 'cros-version:hana',
                            'os_without_colon']
        self.assertEqual(self.info.os, '')


    def test_os_returns_first_match(self):
        """Return the first matching os label."""
        self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
        self.assertEqual(self.info.os, 'linux')


    def test_board_no_match(self):
        """Use proper prefix to search for board information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon']
        self.assertEqual(self.info.board, '')


    def test_board_returns_first_match(self):
        """Return the first matching board label."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
        self.assertEqual(self.info.board, 'walk')


    def test_pools_no_match(self):
        """Use proper prefix to search for pool information."""
        self.info.labels = ['something_else', 'cros-version:hana', 'os:blah',
                            'board_my_board_no_colon', 'board:my_board']
        self.assertEqual(self.info.pools, set())


    def test_pools_returns_all_matches(self):
        """Return all matching pool labels."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored',
                            'pool:first_pool', 'pool:second_pool']
        self.assertEqual(self.info.pools, {'second_pool', 'first_pool'})


    def test_str(self):
        """Sanity checks the __str__ implementation."""
        info = host_info.HostInfo(labels=['a'], attributes={'b': 2})
        self.assertEqual(
                str(info),
                "HostInfo[Labels: ['a'], Attributes: {'b': 2}, "
                "StableVersions: {}]")


    def test_clear_version_labels_no_labels(self):
        """When no version labels exist, do nothing for clear_version_labels."""
        original_labels = ['board:something', 'os:something_else',
                           'pool:mypool', 'cheets-version-corrupted:blah',
                           'cros-version']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, original_labels)


    def test_clear_all_version_labels(self):
        """Clear each recognized type of version label."""
        original_labels = ['extra_label', 'cros-version:cr1',
                           'cheets-version:ab1']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_clear_all_version_label_prefixes(self):
        """Clear each recognized type of version label with empty value."""
        original_labels = ['extra_label', 'cros-version:', 'cheets-version:']
        self.info.labels = list(original_labels)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])


    def test_set_version_labels_updates_in_place(self):
        """Update version label in place if prefix already exists."""
        self.info.labels = ['extra', 'cros-version:X', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'cros-version:Z',
                                                'cheets-version:Y'])


    def test_set_version_labels_appends(self):
        """Append a new version label if the prefix doesn't exist."""
        self.info.labels = ['extra', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        self.assertListEqual(self.info.labels, ['extra', 'cheets-version:Y',
                                                'cros-version:Z'])


class InMemoryHostInfoStoreTest(unittest.TestCase):
    """Basic tests for CachingHostInfoStore using InMemoryHostInfoStore."""

    def setUp(self):
        self.store = host_info.InMemoryHostInfoStore()


    def _verify_host_info_data(self, info, labels, attributes):
        """Verifies the labels and attributes of the given HostInfo.

        @param info: The HostInfo object to check. (Named |info| rather than
                |host_info| so the imported module is not shadowed.)
        @param labels: The expected list of labels.
        @param attributes: The expected attributes dict.
        """
        self.assertListEqual(info.labels, labels)
        self.assertDictEqual(info.attributes, attributes)


    def test_first_get_refreshes_cache(self):
        """Test that the first call to get gets the data from store."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_repeated_get_returns_from_cache(self):
        """Tests that repeated calls to get do not refresh cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_get_uncached_always_refreshes_cache(self):
        """Tests that get(force_refresh=True) always refreshes the cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1', 'label2'], {})


    def test_commit(self):
        """Test that commit sends data to store."""
        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self._verify_host_info_data(self.store.info, [], {})
        self.store.commit(info)
        self._verify_host_info_data(self.store.info, ['label1'],
                                    {'attrib1': 'val1'})


    def test_commit_then_get(self):
        """Test a commit-get roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_then_get_uncached(self):
        """Test a commit / get(force_refresh=True) roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})


    def test_commit_deepcopies_data(self):
        """Once committed, changes to HostInfo don't corrupt the store."""
        info = host_info.HostInfo(['label1'], {'attrib1': {'key1': 'data1'}})
        self.store.commit(info)
        info.labels.append('label2')
        info.attributes['attrib1']['key1'] = 'data2'
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_get_returns_deepcopy(self):
        """The cached object is protected from |get| caller modifications."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})
        got.labels.append('label2')
        got.attributes['attrib1']['key1'] = 'data2'
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})


    def test_str(self):
        """Sanity tests __str__ implementation."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        self.assertEqual(str(self.store),
                         'InMemoryHostInfoStore[%s]' % self.store.info)


class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    """A test store whose refresh / commit raise StoreError by default.

    Tests can set |refresh_raises| / |commit_raises| to False to let the
    corresponding operation succeed.
    """

    def __init__(self):
        super(ExceptionRaisingStore, self).__init__()
        self.refresh_raises = True
        self.commit_raises = True


    def _refresh_impl(self):
        """Raises StoreError if |refresh_raises|, else returns empty HostInfo."""
        if self.refresh_raises:
            raise host_info.StoreError('no can do')
        return host_info.HostInfo()


    def _commit_impl(self, _):
        """Raises StoreError if |commit_raises|, else does nothing."""
        if self.commit_raises:
            raise host_info.StoreError('wont wont wont')


class CachingHostInfoStoreErrorTest(unittest.TestCase):
    """Tests error behaviours of CachingHostInfoStore."""

    def setUp(self):
        self.store = ExceptionRaisingStore()


    def test_failed_refresh_cleans_cache(self):
        """Sanity checks return values when refresh raises."""
        with self.assertRaises(host_info.StoreError):
            self.store.get()
        # Since |get| hit an error, a subsequent get should again hit the store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


    def test_failed_commit_cleans_cache(self):
        """Check that a failed commit cleans the cache."""
        # Let's initialize the store without errors.
        self.store.refresh_raises = False
        self.store.get(force_refresh=True)
        # Re-enable refresh failures: a later get() can only raise if it
        # actually goes back to the store instead of serving the cache.
        self.store.refresh_raises = True

        with self.assertRaises(host_info.StoreError):
            self.store.commit(host_info.HostInfo())
        # Since |commit| hit an error, a subsequent get should again hit the
        # store.
        with self.assertRaises(host_info.StoreError):
            self.store.get()


class GetStoreFromMachineTest(unittest.TestCase):
    """Tests the get_store_from_machine function."""

    def test_machine_is_dict(self):
        """We extract the store when machine is a dict."""
        machine = {
                'something': 'else',
                'host_info_store': 5
        }
        self.assertEqual(host_info.get_store_from_machine(machine), 5)


    def test_machine_is_string(self):
        """We return a trivial store when machine is a string."""
        machine = 'hostname'
        self.assertIsInstance(host_info.get_store_from_machine(machine),
                              host_info.InMemoryHostInfoStore)


class HostInfoJsonSerializationTestCase(unittest.TestCase):
    """Tests the json_serialize and json_deserialize functions."""

    CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION

    def test_serialize_empty(self):
        """Serializing empty HostInfo results in the expected json."""
        info = host_info.HostInfo()
        file_obj = cStringIO.StringIO()
        host_info.json_serialize(info, file_obj)
        file_obj.seek(0)
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {},
                'labels': [],
                'stable_versions': {},
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_serialize_non_empty(self):
        """Serializing a populated HostInfo results in expected json."""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'},
                                  stable_versions={'cros': 'xxx-cros'})
        file_obj = cStringIO.StringIO()
        host_info.json_serialize(info, file_obj)
        file_obj.seek(0)
        expected_dict = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {'attrib': 'val'},
                'labels': ['label1'],
                'stable_versions': {'cros': 'xxx-cros'},
        }
        self.assertEqual(json.load(file_obj), expected_dict)


    def test_round_trip_empty(self):
        """Serializing - deserializing empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo()
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)
        got = host_info.json_deserialize(serialized_fp)
        self.assertEqual(got, info)


    def test_round_trip_non_empty(self):
        """Serializing - deserializing non-empty HostInfo keeps it unchanged."""
        info = host_info.HostInfo(
                labels=['label1'],
                attributes={'attrib': 'val'})
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)
        got = host_info.json_deserialize(serialized_fp)
        self.assertEqual(got, info)


    def test_deserialize_malformed_json_raises(self):
        """Deserializing a malformed string raises."""
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(cStringIO.StringIO('{labels:['))


    def test_deserialize_malformed_host_info_raises(self):
        """Deserializing a malformed host_info raises."""
        info = host_info.HostInfo()
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        serialized_fp.seek(0)

        serialized_dict = json.load(serialized_fp)
        # Drop a mandatory field to make the payload malformed.
        del serialized_dict['labels']
        serialized_no_labels_str = json.dumps(serialized_dict)

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(
                    cStringIO.StringIO(serialized_no_labels_str))


    def test_enforce_compatibility_version_2(self):
        """Tests that required fields are never dropped.

        Never change this test. If you must break compatibility, uprev the
        serializer version and add a new test for the newer version.

        Adding a field to compat_dict means we're making the new field
        mandatory. This breaks backwards compatibility.
        Removing a field from compat_dict means we're no longer requiring a
        field to be mandatory. This breaks forwards compatibility.
        """
        compat_dict = {
                'serializer_version': 2,
                'attributes': {},
                'labels': []
        }
        serialized_str = json.dumps(compat_dict)
        serialized_fp = cStringIO.StringIO(serialized_str)
        host_info.json_deserialize(serialized_fp)


    def test_serialize_pretty_print(self):
        """Serializing a host_info dumps the json in human-friendly format"""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        serialized_fp = cStringIO.StringIO()
        host_info.json_serialize(info, serialized_fp)
        # inspect.cleandoc strips the common leading indentation, leaving the
        # 4-space-per-level layout expected from the serializer.
        expected = """{
            "attributes": {
                "attrib": "val"
            },
            "labels": [
                "label1"
            ],
            "serializer_version": %d,
            "stable_versions": {}
        }""" % self.CURRENT_SERIALIZATION_VERSION
        self.assertEqual(serialized_fp.getvalue(), inspect.cleandoc(expected))


if __name__ == '__main__':
    unittest.main()