• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Lint as: python2, python3
2# Copyright 2016 The Chromium OS Authors. All rights reserved.
3# Use of this source code is governed by a BSD-style license that can be
4# found in the LICENSE file.
5
6import six
7import inspect
8import json
9import unittest
10
11import common
12from autotest_lib.server.hosts import host_info
13
14
class HostInfoTest(unittest.TestCase):
    """Unit tests for the comparison, property and label helpers of HostInfo."""

    def setUp(self):
        self.info = host_info.HostInfo()

    def test_info_comparison_to_wrong_type(self):
        """A HostInfo never compares equal to an object of another type."""
        foreign_values = (42, None)
        for other in foreign_values:
            self.assertNotEqual(host_info.HostInfo(), other)
        # == and != are independent operators in the data model, so the ==
        # operator is exercised explicitly (including against None).
        for other in foreign_values:
            self.assertFalse(host_info.HostInfo() == other)

    def test_empty_infos_are_equal(self):
        """Two default-constructed HostInfo objects compare equal."""
        left = host_info.HostInfo()
        right = host_info.HostInfo()
        self.assertEqual(left, right)
        # == and != are independent operators in the data model.
        self.assertFalse(left != right)

    def test_non_trivial_infos_are_equal(self):
        """Fully populated infos built from the same data compare equal."""

        def build_info():
            # Fresh containers on every call so the two infos share no state.
            return host_info.HostInfo(
                    labels=['label1', 'label2', 'label1'],
                    attributes={'attrib1': None, 'attrib2': 'val2'},
                    stable_versions={
                            'cros': 'xxx-cros',
                            'faft': 'xxx-faft',
                            'firmware': 'xxx-firmware',
                    })

        left = build_info()
        right = build_info()
        self.assertEqual(left, right)
        # == and != are independent operators in the data model.
        self.assertFalse(left != right)

    def test_non_equal_infos(self):
        """HostInfo objects holding different data compare unequal."""
        labels_only = host_info.HostInfo(labels=['label'])
        attributes_only = host_info.HostInfo(attributes={'attrib': 'value'})
        self.assertNotEqual(labels_only, attributes_only)
        # == and != are independent operators in the data model.
        self.assertFalse(labels_only == attributes_only)

    def test_build_needs_prefix(self):
        """Labels lacking the '<type>-version:' form yield no build."""
        self.info.labels = ['cros-version', 'fwrw-version', 'fwro-version']
        build = self.info.build
        self.assertIsNone(build)

    def test_build_prefix_must_be_anchored(self):
        """A version prefix occurring mid-string is not treated as a build."""
        self.info.labels = ['not-at-start-cros-version:cros1']
        build = self.info.build
        self.assertIsNone(build)

    def test_build_ignores_firmware(self):
        """Firmware version labels do not contribute to the build attribute."""
        self.info.labels = ['fwrw-version:fwrw1', 'fwro-version:fwro1']
        build = self.info.build
        self.assertIsNone(build)

    def test_build_returns_first_match(self):
        """With several matching labels, the first one supplies the build."""
        self.info.labels = ['cros-version:cros1', 'cros-version:cros2']
        build = self.info.build
        self.assertEqual(build, 'cros1')

    def test_build_prefer_cros_over_others(self):
        """The cros version wins over cheets regardless of label order."""
        label_orderings = (
                ['cheets-version:ab1', 'cros-version:cros1'],
                ['cros-version:cros1', 'cheets-version:ab1'],
        )
        for labels in label_orderings:
            self.info.labels = labels
            self.assertEqual(self.info.build, 'cros1')

    def test_os_no_match(self):
        """os is the empty string when no 'os:' label is present."""
        unrelated = ['something_else', 'cros-version:hana', 'os_without_colon']
        self.info.labels = unrelated
        self.assertEqual(self.info.os, '')

    def test_os_returns_first_match(self):
        """os comes from the first label carrying the 'os:' prefix."""
        self.info.labels = ['os:linux', 'os:windows', 'os_corrupted_label']
        detected_os = self.info.os
        self.assertEqual(detected_os, 'linux')

    def test_board_no_match(self):
        """board is the empty string when no 'board:' label is present."""
        unrelated = ['something_else', 'cros-version:hana', 'os:blah',
                     'board_my_board_no_colon']
        self.info.labels = unrelated
        self.assertEqual(self.info.board, '')

    def test_board_returns_first_match(self):
        """board comes from the first label carrying the 'board:' prefix."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored']
        detected_board = self.info.board
        self.assertEqual(detected_board, 'walk')

    def test_pools_no_match(self):
        """pools is an empty set when no 'pool:' label is present."""
        unrelated = ['something_else', 'cros-version:hana', 'os:blah',
                     'board_my_board_no_colon', 'board:my_board']
        self.info.labels = unrelated
        self.assertEqual(self.info.pools, set())

    def test_pools_returns_all_matches(self):
        """pools collects the value of every 'pool:' label."""
        self.info.labels = ['board_corrupted', 'board:walk', 'board:bored',
                            'pool:first_pool', 'pool:second_pool']
        expected_pools = set(['first_pool', 'second_pool'])
        self.assertEqual(self.info.pools, expected_pools)

    def test_str(self):
        """Sanity checks the __str__ implementation."""
        rendered = str(host_info.HostInfo(labels=['a'], attributes={'b': 2}))
        self.assertEqual(
                rendered,
                "HostInfo[Labels: ['a'], Attributes: {'b': 2}, StableVersions: {}]")

    def test_clear_version_labels_no_labels(self):
        """clear_version_labels leaves labels alone if none is a version."""
        untouched = ['board:something', 'os:something_else', 'pool:mypool',
                     'cheets-version-corrupted:blah', 'cros-version']
        # Hand the info its own copy so in-place edits would be detected.
        self.info.labels = list(untouched)
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, untouched)

    def test_clear_all_version_labels(self):
        """Every recognized kind of version label is removed."""
        self.info.labels = list(['extra_label', 'cros-version:cr1',
                                 'cheets-version:ab1'])
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])

    def test_clear_all_version_label_prefixes(self):
        """Version labels with an empty value are removed as well."""
        self.info.labels = list(['extra_label', 'cros-version:',
                                 'cheets-version:'])
        self.info.clear_version_labels()
        self.assertListEqual(self.info.labels, ['extra_label'])

    def test_set_version_labels_updates_in_place(self):
        """An existing prefix keeps its position and gets the new value."""
        self.info.labels = ['extra', 'cros-version:X', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        expected = ['extra', 'cros-version:Z', 'cheets-version:Y']
        self.assertListEqual(self.info.labels, expected)

    def test_set_version_labels_appends(self):
        """A prefix that is not present yet is appended at the end."""
        self.info.labels = ['extra', 'cheets-version:Y']
        self.info.set_version_label('cros-version', 'Z')
        expected = ['extra', 'cheets-version:Y', 'cros-version:Z']
        self.assertListEqual(self.info.labels, expected)

    def test_has_level_as_prefix(self):
        """has_label matches a prefix whether or not a value follows it."""
        for labels in (['lb1', 'lb2:Y'], ['lb1', 'lb2:']):
            self.info.labels = labels
            self.assertTrue(self.info.has_label('lb2'))

    def test_has_level_as_value(self):
        """has_label matches a label that has no prefix separator."""
        self.info.labels = ['lb1', 'lb2:Y']
        found = self.info.has_label('lb1')
        self.assertTrue(found)

    def test_has_level_is_not_present(self):
        """has_label is False for absent labels; matching is case-sensitive."""
        self.info.labels = ['lb1', 'lb2:Y']
        for missing in ('lb3', 'LB1'):
            self.assertFalse(self.info.has_label(missing))
197
198
class InMemoryHostInfoStoreTest(unittest.TestCase):
    """Basic tests for CachingHostInfoStore using InMemoryHostInfoStore."""

    def setUp(self):
        self.store = host_info.InMemoryHostInfoStore()

    def _verify_host_info_data(self, info, labels, attributes):
        """Asserts that |info| carries exactly the given labels and attributes.

        The first parameter is deliberately not called |host_info| so that it
        does not shadow the imported host_info module inside this helper.
        """
        self.assertListEqual(info.labels, labels)
        self.assertDictEqual(info.attributes, attributes)

    def test_first_get_refreshes_cache(self):
        """Test that the first call to get gets the data from store."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

    def test_repeated_get_returns_from_cache(self):
        """Tests that repeated calls to get do not refresh cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        # Changing the backing data must not be visible through a plain get.
        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

    def test_get_uncached_always_refreshes_cache(self):
        """Tests that get(force_refresh=True) always refreshes the cache."""
        self.store.info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

        self.store.info = host_info.HostInfo(['label1', 'label2'], {})
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1', 'label2'], {})

    def test_commit(self):
        """Test that commit sends data to store."""
        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        # The store starts out empty.
        self._verify_host_info_data(self.store.info, [], {})
        self.store.commit(info)
        self._verify_host_info_data(self.store.info, ['label1'],
                                    {'attrib1': 'val1'})

    def test_commit_then_get(self):
        """Test a commit-get roundtrip."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get()
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

    def test_commit_then_get_uncached(self):
        """Test a commit followed by a get with force_refresh=True."""
        got = self.store.get()
        self._verify_host_info_data(got, [], {})

        info = host_info.HostInfo(['label1'], {'attrib1': 'val1'})
        self.store.commit(info)
        got = self.store.get(force_refresh=True)
        self._verify_host_info_data(got, ['label1'], {'attrib1': 'val1'})

    def test_commit_deepcopies_data(self):
        """Once committed, changes to HostInfo don't corrupt the store."""
        info = host_info.HostInfo(['label1'], {'attrib1': {'key1': 'data1'}})
        self.store.commit(info)
        # Mutate both a top-level container and a nested one after the commit.
        info.labels.append('label2')
        info.attributes['attrib1']['key1'] = 'data2'
        self._verify_host_info_data(self.store.info,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})

    def test_get_returns_deepcopy(self):
        """The cached object is protected from |get| caller modifications."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})
        # Mutate both a top-level container and a nested one on the result.
        got.labels.append('label2')
        got.attributes['attrib1']['key1'] = 'data2'
        got = self.store.get()
        self._verify_host_info_data(got,
                                    ['label1'], {'attrib1': {'key1': 'data1'}})

    def test_str(self):
        """Sanity tests __str__ implementation."""
        self.store.info = host_info.HostInfo(['label1'],
                                             {'attrib1': {'key1': 'data1'}})
        self.assertEqual(str(self.store),
                         'InMemoryHostInfoStore[%s]' % self.store.info)
302
303
class ExceptionRaisingStore(host_info.CachingHostInfoStore):
    """A CachingHostInfoStore whose refresh / commit hooks fail on demand.

    Both hooks raise StoreError for as long as the matching |refresh_raises| /
    |commit_raises| flag is truthy; both flags start out True.
    """

    def __init__(self):
        super(ExceptionRaisingStore, self).__init__()
        self.commit_raises = True
        self.refresh_raises = True

    def _refresh_impl(self):
        if not self.refresh_raises:
            return host_info.HostInfo()
        raise host_info.StoreError('no can do')

    def _commit_impl(self, _):
        if not self.commit_raises:
            return
        raise host_info.StoreError('wont wont wont')
321
322
class CachingHostInfoStoreErrorTest(unittest.TestCase):
    """Tests error behaviours of CachingHostInfoStore."""

    def setUp(self):
        self.store = ExceptionRaisingStore()

    def test_failed_refresh_cleans_cache(self):
        """A get that fails to refresh leaves nothing behind in the cache."""
        self.assertRaises(host_info.StoreError, self.store.get)
        # The first |get| failed, so the next one has to hit the store again
        # and therefore fails the same way.
        self.assertRaises(host_info.StoreError, self.store.get)

    def test_failed_commit_cleans_cache(self):
        """Check that a failed commit cleans the cache."""
        store = self.store
        # Populate the cache once while refresh is allowed to succeed.
        store.refresh_raises = False
        store.get(force_refresh=True)
        store.refresh_raises = True

        self.assertRaises(host_info.StoreError,
                          store.commit, host_info.HostInfo())
        # The failed |commit| must have dropped the cache, so a plain get has
        # to hit the (now failing) store again.
        self.assertRaises(host_info.StoreError, store.get)
352
353
class GetStoreFromMachineTest(unittest.TestCase):
    """Tests the get_store_from_machine function."""

    def test_machine_is_dict(self):
        """We extract the store when machine is a dict."""
        machine = {
                'something': 'else',
                'host_info_store': 5
        }
        self.assertEqual(host_info.get_store_from_machine(machine), 5)

    def test_machine_is_string(self):
        """We return a trivial store when machine is a string."""
        machine = 'hostname'
        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(host_info.get_store_from_machine(machine),
                              host_info.InMemoryHostInfoStore)
371
372
class HostInfoJsonSerializationTestCase(unittest.TestCase):
    """Tests the json_serialize and json_deserialize functions."""

    CURRENT_SERIALIZATION_VERSION = host_info._CURRENT_SERIALIZATION_VERSION

    def _serialize_and_rewind(self, info):
        """Serializes |info| into a new StringIO and rewinds it to offset 0."""
        buf = six.StringIO()
        host_info.json_serialize(info, buf)
        buf.seek(0)
        return buf

    def test_serialize_empty(self):
        """Serializing empty HostInfo results in the expected json."""
        serialized = self._serialize_and_rewind(host_info.HostInfo())
        want = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {},
                'labels': [],
                'stable_versions': {},
        }
        self.assertEqual(json.load(serialized), want)

    def test_serialize_non_empty(self):
        """Serializing a populated HostInfo results in expected json."""
        populated = host_info.HostInfo(labels=['label1'],
                                       attributes={'attrib': 'val'},
                                       stable_versions={'cros': 'xxx-cros'})
        serialized = self._serialize_and_rewind(populated)
        want = {
                'serializer_version': self.CURRENT_SERIALIZATION_VERSION,
                'attributes': {'attrib': 'val'},
                'labels': ['label1'],
                'stable_versions': {'cros': 'xxx-cros'},
        }
        self.assertEqual(json.load(serialized), want)

    def test_round_trip_empty(self):
        """Serializing - deserializing empty HostInfo keeps it unchanged."""
        original = host_info.HostInfo()
        restored = host_info.json_deserialize(
                self._serialize_and_rewind(original))
        self.assertEqual(restored, original)

    def test_round_trip_non_empty(self):
        """Serializing - deserializing non-empty HostInfo keeps it unchanged."""
        original = host_info.HostInfo(labels=['label1'],
                                      attributes={'attrib': 'val'})
        restored = host_info.json_deserialize(
                self._serialize_and_rewind(original))
        self.assertEqual(restored, original)

    def test_deserialize_malformed_json_raises(self):
        """Deserializing a malformed string raises."""
        broken_fp = six.StringIO('{labels:[')
        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(broken_fp)

    def test_deserialize_malformed_host_info_raises(self):
        """Deserializing valid json that lacks the 'labels' field raises."""
        payload = json.load(self._serialize_and_rewind(host_info.HostInfo()))
        # Drop a field from an otherwise well-formed serialization.
        payload.pop('labels')
        without_labels_fp = six.StringIO(json.dumps(payload))

        with self.assertRaises(host_info.DeserializationError):
            host_info.json_deserialize(without_labels_fp)

    def test_enforce_compatibility_version_2(self):
        """Tests that required fields are never dropped.

        Never change this test. If you must break compatibility, uprev the
        serializer version and add a new test for the newer version.

        Adding a field to compat_info_str means we're making the new field
        mandatory. This breaks backwards compatibility.
        Removing a field from compat_info_str means we're no longer requiring a
        field to be mandatory. This breaks forwards compatibility.
        """
        compat_dict = {
                'serializer_version': 2,
                'attributes': {},
                'labels': []
        }
        serialized_str = json.dumps(compat_dict)
        serialized_fp = six.StringIO(serialized_str)
        host_info.json_deserialize(serialized_fp)

    def test_serialize_pretty_print(self):
        """Serializing a host_info dumps the json in human-friendly format"""
        info = host_info.HostInfo(labels=['label1'],
                                  attributes={'attrib': 'val'})
        serialized_fp = six.StringIO()
        host_info.json_serialize(info, serialized_fp)
        # cleandoc strips the indentation this literal picks up from the
        # surrounding code, so it is compared as plainly indented json.
        template = """{
            "attributes": {
                "attrib": "val"
            },
            "labels": [
                "label1"
            ],
            "serializer_version": %d,
            "stable_versions": {}
        }"""
        want = inspect.cleandoc(template % self.CURRENT_SERIALIZATION_VERSION)
        self.assertEqual(serialized_fp.getvalue(), want)
492
493
# Run every test case defined in this module when the file is executed
# directly rather than imported.
if __name__ == '__main__':
    unittest.main()
496