#!/usr/bin/env python
#
# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for acloud.internal.lib.utils."""

import errno
import getpass
import grp
import os
import shutil
import subprocess
import tempfile
import time

import unittest
import mock

from acloud import errors
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import utils

# Tkinter may not be supported so mock it out.
try:
    import Tkinter
except ImportError:
    Tkinter = mock.Mock()


class FakeTkinter(object):
    """Fake implementation of Tkinter.Tk() reporting a fixed screen size."""

    def __init__(self, width=None, height=None):
        self.width = width
        self.height = height

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Return the screen height."""
        return self.height

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Return the screen width."""
        return self.width


# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Test Utils.

    Note: unittest only collects methods whose names start with a lowercase
    "test"; a method named "Test..." is silently never run.
    """

    def testTempDirSuccess(self):
        """Test create a temp dir."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        self.assertEqual(tempfile.mkdtemp.call_count, 1)  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised.
        self.assertRaises(ExpectedException, _Call)
        self.assertEqual(tempfile.mkdtemp.call_count, 1)  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        self.assertEqual(tempfile.mkdtemp.call_count, 1)  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        self.assertEqual(tempfile.mkdtemp.call_count, 1)  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOriginalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        self.assertEqual(tempfile.mkdtemp.call_count, 1)  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test when the key pair already exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test when the key pair created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test when the PublicKey created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch("__builtin__.open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):
        """Test RetryOnException decorator retries until attempts run out."""
        # Patch out sleep so the retries never slow the test down, whatever
        # RetryOnException's default back-off happens to be.
        self.Patch(time, "sleep")

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        # One initial attempt plus num_retry retries.
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0,  # sleep_multiplier
            1,  # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry sleeps with exponential back-off between attempts."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1,  # sleep_multiplier
            2,  # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch("__builtin__.raw_input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList."""
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock), "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    # pylint: disable=protected-access
    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(os, "getgroups", return_value=[1, 2, 3])
        gr1 = mock.MagicMock()
        gr1.gr_name = "fake_gr_1"
        gr2 = mock.MagicMock()
        gr2.gr_name = "fake_gr_2"
        gr3 = mock.MagicMock()
        gr3.gr_name = "fake_gr_3"
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])

        # User in all required groups should return true.
        self.assertTrue(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_2"]))

        # User not in all required groups should return False.
        self.Patch(grp, "getgrgid", side_effect=[gr1, gr2, gr3])
        self.assertFalse(
            utils.CheckUserInGroups(
                ["fake_gr_1", "fake_gr_4"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # The ScpPullFile tests patch subprocess.check_call via self.Patch rather
    # than assigning to it directly: a direct module-attribute assignment is
    # never undone and would leak a MagicMock into every later test.
    def testScpPullFileSuccess(self):
        """Test scp pull file successfully."""
        mock_check_call = self.Patch(subprocess, "check_call")
        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1")
        mock_check_call.assert_called_with(utils.SCP_CMD + [
            "192.168.0.1:/tmp/test", "/tmp/test_1.log"])

    def testScpPullFileWithUserNameSuccess(self):
        """Test scp pull file successfully with a user name."""
        mock_check_call = self.Patch(subprocess, "check_call")
        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1",
                          user_name="abc")
        mock_check_call.assert_called_with(utils.SCP_CMD + [
            "abc@192.168.0.1:/tmp/test", "/tmp/test_1.log"])

    # pylint: disable=invalid-name
    def testScpPullFileWithUserNameWithRsaKeySuccess(self):
        """Test scp pull file successfully with a user name and RSA key."""
        mock_check_call = self.Patch(subprocess, "check_call")
        utils.ScpPullFile("/tmp/test", "/tmp/test_1.log", "192.168.0.1",
                          user_name="abc", rsa_key_file="/tmp/my_key")
        mock_check_call.assert_called_with(utils.SCP_CMD + [
            "-i", "/tmp/my_key", "abc@192.168.0.1:/tmp/test",
            "/tmp/test_1.log"])

    def testScpPullFileScpFailure(self):
        """Test scp pull file failure."""
        self.Patch(
            subprocess, "check_call",
            side_effect=subprocess.CalledProcessError(123, "fake",
                                                      "fake error"))
        self.assertRaises(
            errors.DeviceConnectionError,
            utils.ScpPullFile, "/tmp/test", "/tmp/test_1.log", "192.168.0.1")

    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test AutoConnect returns empty ports when the SSH tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))


if __name__ == "__main__":
    unittest.main()