• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2016 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16"""Tests for acloud.internal.lib.utils."""
17
18import errno
19import getpass
20import grp
21import os
22import shutil
23import subprocess
24import tempfile
25import time
26
27import unittest
28import mock
29
30from acloud import errors
31from acloud.internal.lib import driver_test_lib
32from acloud.internal.lib import utils
33
34# Tkinter may not be supported so mock it out.
35try:
36    import Tkinter
37except ImportError:
38    Tkinter = mock.Mock()
39
class FakeTkinter(object):
    """Minimal stand-in for Tkinter.Tk() that reports a fixed screen size.

    Only the two screen-dimension getters are implemented; each one simply
    hands back the value given at construction time.
    """
    # The getter names mirror Tkinter's own spelling, hence the exemption.
    # pylint: disable=invalid-name

    def __init__(self, width=None, height=None):
        """Remember the screen size to report.

        Args:
            width: Value returned by winfo_screenwidth().
            height: Value returned by winfo_screenheight().
        """
        self.width, self.height = width, height

    def winfo_screenheight(self):
        """Return the configured screen height."""
        return self.height

    def winfo_screenwidth(self):
        """Return the configured screen width."""
        return self.width
56
57
58# pylint: disable=too-many-public-methods
59class UtilsTest(driver_test_lib.BaseDriverTest):
60    """Test Utils."""
61
62    def TestTempDirSuccess(self):
63        """Test create a temp dir."""
64        self.Patch(os, "chmod")
65        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
66        self.Patch(shutil, "rmtree")
67        with utils.TempDir():
68            pass
69        # Verify.
70        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
71        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member
72
73    def TestTempDirExceptionRaised(self):
74        """Test create a temp dir and exception is raised within with-clause."""
75        self.Patch(os, "chmod")
76        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
77        self.Patch(shutil, "rmtree")
78
79        class ExpectedException(Exception):
80            """Expected exception."""
81            pass
82
83        def _Call():
84            with utils.TempDir():
85                raise ExpectedException("Expected exception.")
86
87        # Verify. ExpectedException should be raised.
88        self.assertRaises(ExpectedException, _Call)
89        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
90        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member
91
92    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
93        """Test create a temp dir and dir no longer exists during deletion."""
94        self.Patch(os, "chmod")
95        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
96        expected_error = EnvironmentError()
97        expected_error.errno = errno.ENOENT
98        self.Patch(shutil, "rmtree", side_effect=expected_error)
99
100        def _Call():
101            with utils.TempDir():
102                pass
103
104        # Verify no exception should be raised when rmtree raises
105        # EnvironmentError with errno.ENOENT, i.e.
106        # directory no longer exists.
107        _Call()
108        tempfile.mkdtemp.assert_called_once()  #pylint: disable=no-member
109        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member
110
111    def testTempDirWhenDeleteEncounterError(self):
112        """Test create a temp dir and encoutered error during deletion."""
113        self.Patch(os, "chmod")
114        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
115        expected_error = OSError("Expected OS Error")
116        self.Patch(shutil, "rmtree", side_effect=expected_error)
117
118        def _Call():
119            with utils.TempDir():
120                pass
121
122        # Verify OSError should be raised.
123        self.assertRaises(OSError, _Call)
124        tempfile.mkdtemp.assert_called_once()  #pylint: disable=no-member
125        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member
126
127    def testTempDirOrininalErrorRaised(self):
128        """Test original error is raised even if tmp dir deletion failed."""
129        self.Patch(os, "chmod")
130        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
131        expected_error = OSError("Expected OS Error")
132        self.Patch(shutil, "rmtree", side_effect=expected_error)
133
134        class ExpectedException(Exception):
135            """Expected exception."""
136            pass
137
138        def _Call():
139            with utils.TempDir():
140                raise ExpectedException("Expected Exception")
141
142        # Verify.
143        # ExpectedException should be raised, and OSError
144        # should not be raised.
145        self.assertRaises(ExpectedException, _Call)
146        tempfile.mkdtemp.assert_called_once()  #pylint: disable=no-member
147        shutil.rmtree.assert_called_with("/tmp/tempdir")  #pylint: disable=no-member
148
149    def testCreateSshKeyPairKeyAlreadyExists(self):  #pylint: disable=invalid-name
150        """Test when the key pair already exists."""
151        public_key = "/fake/public_key"
152        private_key = "/fake/private_key"
153        self.Patch(os.path, "exists", side_effect=[True, True])
154        self.Patch(subprocess, "check_call")
155        self.Patch(os, "makedirs", return_value=True)
156        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
157        self.assertEqual(subprocess.check_call.call_count, 0)  #pylint: disable=no-member
158
159    def testCreateSshKeyPairKeyAreCreated(self):
160        """Test when the key pair created."""
161        public_key = "/fake/public_key"
162        private_key = "/fake/private_key"
163        self.Patch(os.path, "exists", return_value=False)
164        self.Patch(os, "makedirs", return_value=True)
165        self.Patch(subprocess, "check_call")
166        self.Patch(os, "rename")
167        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
168        self.assertEqual(subprocess.check_call.call_count, 1)  #pylint: disable=no-member
169        subprocess.check_call.assert_called_with(  #pylint: disable=no-member
170            utils.SSH_KEYGEN_CMD +
171            ["-C", getpass.getuser(), "-f", private_key],
172            stdout=mock.ANY,
173            stderr=mock.ANY)
174
175    def testCreatePublicKeyAreCreated(self):
176        """Test when the PublicKey created."""
177        public_key = "/fake/public_key"
178        private_key = "/fake/private_key"
179        self.Patch(os.path, "exists", side_effect=[False, True, True])
180        self.Patch(os, "makedirs", return_value=True)
181        mock_open = mock.mock_open(read_data=public_key)
182        self.Patch(subprocess, "check_output")
183        self.Patch(os, "rename")
184        with mock.patch("__builtin__.open", mock_open):
185            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
186        self.assertEqual(subprocess.check_output.call_count, 1)  #pylint: disable=no-member
187        subprocess.check_output.assert_called_with(  #pylint: disable=no-member
188            utils.SSH_KEYGEN_PUB_CMD +["-f", private_key])
189
190    def TestRetryOnException(self):
191        """Test Retry."""
192
193        def _IsValueError(exc):
194            return isinstance(exc, ValueError)
195
196        num_retry = 5
197
198        @utils.RetryOnException(_IsValueError, num_retry)
199        def _RaiseAndRetry(sentinel):
200            sentinel.alert()
201            raise ValueError("Fake error.")
202
203        sentinel = mock.MagicMock()
204        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
205        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
206
207    def testRetryExceptionType(self):
208        """Test RetryExceptionType function."""
209
210        def _RaiseAndRetry(sentinel):
211            sentinel.alert()
212            raise ValueError("Fake error.")
213
214        num_retry = 5
215        sentinel = mock.MagicMock()
216        self.assertRaises(
217            ValueError,
218            utils.RetryExceptionType, (KeyError, ValueError),
219            num_retry,
220            _RaiseAndRetry,
221            0, # sleep_multiplier
222            1, # retry_backoff_factor
223            sentinel=sentinel)
224        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
225
226    def testRetry(self):
227        """Test Retry."""
228        mock_sleep = self.Patch(time, "sleep")
229
230        def _RaiseAndRetry(sentinel):
231            sentinel.alert()
232            raise ValueError("Fake error.")
233
234        num_retry = 5
235        sentinel = mock.MagicMock()
236        self.assertRaises(
237            ValueError,
238            utils.RetryExceptionType, (ValueError, KeyError),
239            num_retry,
240            _RaiseAndRetry,
241            1, # sleep_multiplier
242            2, # retry_backoff_factor
243            sentinel=sentinel)
244
245        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
246        mock_sleep.assert_has_calls(
247            [
248                mock.call(1),
249                mock.call(2),
250                mock.call(4),
251                mock.call(8),
252                mock.call(16)
253            ])
254
255    @mock.patch("__builtin__.raw_input")
256    def testGetAnswerFromList(self, mock_raw_input):
257        """Test GetAnswerFromList."""
258        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
259        mock_raw_input.return_value = 0
260        with self.assertRaises(SystemExit):
261            utils.GetAnswerFromList(answer_list)
262        mock_raw_input.side_effect = [1, 2, 3, 4]
263        self.assertEqual(utils.GetAnswerFromList(answer_list),
264                         ["image1.zip"])
265        self.assertEqual(utils.GetAnswerFromList(answer_list),
266                         ["image2.zip"])
267        self.assertEqual(utils.GetAnswerFromList(answer_list),
268                         ["image3.zip"])
269        self.assertEqual(utils.GetAnswerFromList(answer_list,
270                                                 enable_choose_all=True),
271                         answer_list)
272
273    @unittest.skipIf(isinstance(Tkinter, mock.Mock), "Tkinter mocked out, test case not needed.")
274    @mock.patch.object(Tkinter, "Tk")
275    def testCalculateVNCScreenRatio(self, mock_tk):
276        """Test Calculating the scale ratio of VNC display."""
277        # Get scale-down ratio if screen height is smaller than AVD height.
278        mock_tk.return_value = FakeTkinter(height=800, width=1200)
279        avd_h = 1920
280        avd_w = 1080
281        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)
282
283        # Get scale-down ratio if screen width is smaller than AVD width.
284        mock_tk.return_value = FakeTkinter(height=800, width=1200)
285        avd_h = 900
286        avd_w = 1920
287        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)
288
289        # Scale ratio = 1 if screen is larger than AVD.
290        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
291        avd_h = 800
292        avd_w = 1280
293        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)
294
295        # Get the scale if ratio of width is smaller than the
296        # ratio of height.
297        mock_tk.return_value = FakeTkinter(height=1200, width=800)
298        avd_h = 1920
299        avd_w = 1080
300        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)
301
302    # pylint: disable=protected-access
303    def testCheckUserInGroups(self):
304        """Test CheckUserInGroups."""
305        self.Patch(os, "getgroups", return_value=[1, 2, 3])
306        gr1 = mock.MagicMock()
307        gr1.gr_name = "fake_gr_1"
308        gr2 = mock.MagicMock()
309        gr2.gr_name = "fake_gr_2"
310        gr3 = mock.MagicMock()
311        gr3.gr_name = "fake_gr_3"
312        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
313
314        # User in all required groups should return true.
315        self.assertTrue(
316            utils.CheckUserInGroups(
317                ["fake_gr_1", "fake_gr_2"]))
318
319        # User not in all required groups should return False.
320        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
321        self.assertFalse(
322            utils.CheckUserInGroups(
323                ["fake_gr_1", "fake_gr_4"]))
324
325    @mock.patch.object(utils, "CheckUserInGroups")
326    def testAddUserGroupsToCmd(self, mock_user_group):
327        """Test AddUserGroupsToCmd."""
328        command = "test_command"
329        groups = ["group1", "group2"]
330        # Don't add user group in command
331        mock_user_group.return_value = True
332        expected_value = "test_command"
333        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
334                                                                  groups))
335
336        # Add user group in command
337        mock_user_group.return_value = False
338        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
339        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
340                                                                  groups))
341
342    @staticmethod
343    def testScpPullFileSuccess():
344        """Test scp pull file successfully."""
345        subprocess.check_call = mock.MagicMock()
346        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1")
347        subprocess.check_call.assert_called_with(utils.SCP_CMD + [
348            "192.168.0.1:/tmp/test", "/tmp/test_1.log"])
349
350    @staticmethod
351    def testScpPullFileWithUserNameSuccess():
352        """Test scp pull file successfully."""
353        subprocess.check_call = mock.MagicMock()
354        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1",
355                          user_name="abc")
356        subprocess.check_call.assert_called_with(utils.SCP_CMD + [
357            "abc@192.168.0.1:/tmp/test", "/tmp/test_1.log"])
358
359    # pylint: disable=invalid-name
360    @staticmethod
361    def testScpPullFileWithUserNameWithRsaKeySuccess():
362        """Test scp pull file successfully."""
363        subprocess.check_call = mock.MagicMock()
364        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1",
365                          user_name="abc", rsa_key_file="/tmp/my_key")
366        subprocess.check_call.assert_called_with(utils.SCP_CMD + [
367            "-i", "/tmp/my_key", "abc@192.168.0.1:/tmp/test",
368            "/tmp/test_1.log"])
369
370    def testScpPullFileScpFailure(self):
371        """Test scp pull file failure."""
372        subprocess.check_call = mock.MagicMock(
373            side_effect=subprocess.CalledProcessError(123, "fake",
374                                                      "fake error"))
375        self.assertRaises(
376            errors.DeviceConnectionError,
377            utils.ScpPullFile, "/tmp/test", "/tmp/test_1.log", "192.168.0.1")
378
379
380    def testTimeoutException(self):
381        """Test TimeoutException."""
382        @utils.TimeoutException(1, "should time out")
383        def functionThatWillTimeOut():
384            """Test decorator of @utils.TimeoutException should timeout."""
385            time.sleep(5)
386
387        self.assertRaises(errors.FunctionTimeoutError,
388                          functionThatWillTimeOut)
389
390
391    def testTimeoutExceptionNoTimeout(self):
392        """Test No TimeoutException."""
393        @utils.TimeoutException(5, "shouldn't time out")
394        def functionThatShouldNotTimeout():
395            """Test decorator of @utils.TimeoutException shouldn't timeout."""
396            return None
397        try:
398            functionThatShouldNotTimeout()
399        except errors.FunctionTimeoutError:
400            self.fail("shouldn't timeout")
401
402    def testAutoConnectCreateSSHTunnelFail(self):
403        """test auto connect."""
404        fake_ip_addr = "1.1.1.1"
405        fake_rsa_key_file = "/tmp/rsa_file"
406        fake_target_vnc_port = 8888
407        target_adb_port = 9999
408        ssh_user = "fake_user"
409        call_side_effect = subprocess.CalledProcessError(123, "fake",
410                                                         "fake error")
411        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
412        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
413        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
414                                                   fake_rsa_key_file,
415                                                   fake_target_vnc_port,
416                                                   target_adb_port,
417                                                   ssh_user))
418
419
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
422