#!/usr/bin/env python
#
# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for acloud.internal.lib.utils."""

import collections
import errno
import getpass
import grp
import os
import shutil
import subprocess
import tempfile
import time
import unittest
from unittest import mock
import webbrowser

import six

from acloud import errors
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import utils


# Minimal stand-in for a grp.struct_group record; the fields mirror those of
# grp.struct_group so it can be returned from patched grp.getgrall() /
# grp.getgrnam() calls.
GroupInfo = collections.namedtuple(
    "GroupInfo", "gr_name gr_passwd gr_gid gr_mem")

# Tkinter may not be supported so mock it out.
try:
    # NOTE(review): "Tkinter" is the Python 2 module name (Python 3 renamed it
    # to "tkinter"), while this file already relies on Python 3's
    # unittest.mock. Under Python 3 this import therefore most likely always
    # falls back to the mock below, and testCalculateVNCScreenRatio is always
    # skipped. Confirm which name utils.CalculateVNCScreenRatio imports before
    # adding a "tkinter" fallback here.
    import Tkinter
except ImportError:
    Tkinter = mock.Mock()


class FakeTkinter:
    """Fake implementation of Tkinter.Tk().

    Reports a fixed screen size through winfo_screenwidth() and
    winfo_screenheight().
    """

    def __init__(self, width=None, height=None):
        self.width = width
        self.height = height

    # pylint: disable=invalid-name
    def winfo_screenheight(self):
        """Return the screen height."""
        return self.height

    # pylint: disable=invalid-name
    def winfo_screenwidth(self):
        """Return the screen width."""
        return self.width


# pylint: disable=too-many-public-methods
class UtilsTest(driver_test_lib.BaseDriverTest):
    """Test Utils."""

    # Test method names must start with a lowercase "test": unittest's default
    # loader (TestLoader.testMethodPrefix == "test") silently skips methods
    # named "Test...", so those tests would never run.
    def testTempDirSuccess(self):
        """Test create a temp dir."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")
        with utils.TempDir():
            pass
        # Verify.
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirExceptionRaised(self):
        """Test create a temp dir and exception is raised within with-clause."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        self.Patch(shutil, "rmtree")

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected exception.")

        # Verify. ExpectedException should be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteTempDirNoLongerExist(self):  # pylint: disable=invalid-name
        """Test create a temp dir and dir no longer exists during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = EnvironmentError()
        expected_error.errno = errno.ENOENT
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify no exception should be raised when rmtree raises
        # EnvironmentError with errno.ENOENT, i.e.
        # directory no longer exists.
        _Call()
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirWhenDeleteEncounterError(self):
        """Test create a temp dir and encountered error during deletion."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        def _Call():
            with utils.TempDir():
                pass

        # Verify OSError should be raised.
        self.assertRaises(OSError, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testTempDirOrininalErrorRaised(self):
        """Test original error is raised even if tmp dir deletion failed."""
        self.Patch(os, "chmod")
        self.Patch(tempfile, "mkdtemp", return_value="/tmp/tempdir")
        expected_error = OSError("Expected OS Error")
        self.Patch(shutil, "rmtree", side_effect=expected_error)

        class ExpectedException(Exception):
            """Expected exception."""

        def _Call():
            with utils.TempDir():
                raise ExpectedException("Expected Exception")

        # Verify.
        # ExpectedException should be raised, and OSError
        # should not be raised.
        self.assertRaises(ExpectedException, _Call)
        tempfile.mkdtemp.assert_called_once()  # pylint: disable=no-member
        shutil.rmtree.assert_called_with("/tmp/tempdir")  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAlreadyExists(self):  # pylint: disable=invalid-name
        """Test when the key pair already exists."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[True, True])
        self.Patch(subprocess, "check_call")
        self.Patch(os, "makedirs", return_value=True)
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 0)  # pylint: disable=no-member

    def testCreateSshKeyPairKeyAreCreated(self):
        """Test when the key pair created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", return_value=False)
        self.Patch(os, "makedirs", return_value=True)
        self.Patch(subprocess, "check_call")
        self.Patch(os, "rename")
        utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_call.call_count, 1)  # pylint: disable=no-member
        subprocess.check_call.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_CMD +
            ["-C", getpass.getuser(), "-f", private_key],
            stdout=mock.ANY,
            stderr=mock.ANY)

    def testCreatePublicKeyAreCreated(self):
        """Test when the PublicKey created."""
        public_key = "/fake/public_key"
        private_key = "/fake/private_key"
        self.Patch(os.path, "exists", side_effect=[False, True, True])
        self.Patch(os, "makedirs", return_value=True)
        mock_open = mock.mock_open(read_data=public_key)
        self.Patch(subprocess, "check_output")
        self.Patch(os, "rename")
        with mock.patch.object(six.moves.builtins, "open", mock_open):
            utils.CreateSshKeyPairIfNotExist(private_key, public_key)
        self.assertEqual(subprocess.check_output.call_count, 1)  # pylint: disable=no-member
        subprocess.check_output.assert_called_with(  # pylint: disable=no-member
            utils.SSH_KEYGEN_PUB_CMD + ["-f", private_key])

    def testRetryOnException(self):
        """Test RetryOnException retries while the checker accepts the error."""

        def _IsValueError(exc):
            return isinstance(exc, ValueError)

        num_retry = 5

        @utils.RetryOnException(_IsValueError, num_retry)
        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        sentinel = mock.MagicMock()
        self.assertRaises(ValueError, _RaiseAndRetry, sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetryExceptionType(self):
        """Test RetryExceptionType function."""

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (KeyError, ValueError),
            num_retry,
            _RaiseAndRetry,
            0,  # sleep_multiplier
            1,  # retry_backoff_factor
            sentinel=sentinel)
        self.assertEqual(1 + num_retry, sentinel.alert.call_count)

    def testRetry(self):
        """Test Retry."""
        mock_sleep = self.Patch(time, "sleep")

        def _RaiseAndRetry(sentinel):
            sentinel.alert()
            raise ValueError("Fake error.")

        num_retry = 5
        sentinel = mock.MagicMock()
        self.assertRaises(
            ValueError,
            utils.RetryExceptionType, (ValueError, KeyError),
            num_retry,
            _RaiseAndRetry,
            1,  # sleep_multiplier
            2,  # retry_backoff_factor
            sentinel=sentinel)

        self.assertEqual(1 + num_retry, sentinel.alert.call_count)
        mock_sleep.assert_has_calls(
            [
                mock.call(1),
                mock.call(2),
                mock.call(4),
                mock.call(8),
                mock.call(16)
            ])

    @mock.patch.object(six.moves, "input")
    def testGetAnswerFromList(self, mock_raw_input):
        """Test GetAnswerFromList."""
        answer_list = ["image1.zip", "image2.zip", "image3.zip"]
        mock_raw_input.return_value = 0
        with self.assertRaises(SystemExit):
            utils.GetAnswerFromList(answer_list)
        mock_raw_input.side_effect = [1, 2, 3, 4]
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image1.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image2.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list),
                         ["image3.zip"])
        self.assertEqual(utils.GetAnswerFromList(answer_list,
                                                 enable_choose_all=True),
                         answer_list)

    @unittest.skipIf(isinstance(Tkinter, mock.Mock),
                     "Tkinter mocked out, test case not needed.")
    @mock.patch.object(Tkinter, "Tk")
    def testCalculateVNCScreenRatio(self, mock_tk):
        """Test Calculating the scale ratio of VNC display."""
        # Get scale-down ratio if screen height is smaller than AVD height.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.4)

        # Get scale-down ratio if screen width is smaller than AVD width.
        mock_tk.return_value = FakeTkinter(height=800, width=1200)
        avd_h = 900
        avd_w = 1920
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

        # Scale ratio = 1 if screen is larger than AVD.
        mock_tk.return_value = FakeTkinter(height=1080, width=1920)
        avd_h = 800
        avd_w = 1280
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 1)

        # Get the scale if ratio of width is smaller than the
        # ratio of height.
        mock_tk.return_value = FakeTkinter(height=1200, width=800)
        avd_h = 1920
        avd_w = 1080
        self.assertEqual(utils.CalculateVNCScreenRatio(avd_w, avd_h), 0.6)

    def testCheckUserInGroups(self):
        """Test CheckUserInGroups."""
        self.Patch(getpass, "getuser", return_value="user_0")
        self.Patch(grp, "getgrall", return_value=[
            GroupInfo("fake_group1", "passwd_1", 0, ["user_1", "user_2"]),
            GroupInfo("fake_group2", "passwd_2", 1, ["user_1", "user_2"])])
        self.Patch(grp, "getgrnam", return_value=GroupInfo(
            "fake_group1", "passwd_1", 0, ["user_1", "user_2"]))
        # Test Group name doesn't exist.
        self.assertFalse(utils.CheckUserInGroups(["Non_exist_group"]))

        # Test User isn't in group.
        self.assertFalse(utils.CheckUserInGroups(["fake_group1"]))

        # Test User is in group.
        self.Patch(getpass, "getuser", return_value="user_1")
        self.assertTrue(utils.CheckUserInGroups(["fake_group1"]))

    @mock.patch.object(utils, "CheckUserInGroups")
    def testAddUserGroupsToCmd(self, mock_user_group):
        """Test AddUserGroupsToCmd."""
        command = "test_command"
        groups = ["group1", "group2"]
        # Don't add user group in command
        mock_user_group.return_value = True
        expected_value = "test_command"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

        # Add user group in command
        mock_user_group.return_value = False
        expected_value = "sg group1 <<EOF\nsg group2\ntest_command\nEOF"
        self.assertEqual(expected_value, utils.AddUserGroupsToCmd(command,
                                                                  groups))

    # pylint: disable=invalid-name
    def testTimeoutException(self):
        """Test TimeoutException."""
        @utils.TimeoutException(1, "should time out")
        def functionThatWillTimeOut():
            """Test decorator of @utils.TimeoutException should timeout."""
            time.sleep(5)

        self.assertRaises(errors.FunctionTimeoutError,
                          functionThatWillTimeOut)

    def testTimeoutExceptionNoTimeout(self):
        """Test No TimeoutException."""
        @utils.TimeoutException(5, "shouldn't time out")
        def functionThatShouldNotTimeout():
            """Test decorator of @utils.TimeoutException shouldn't timeout."""
            return None
        try:
            functionThatShouldNotTimeout()
        except errors.FunctionTimeoutError:
            self.fail("shouldn't timeout")

    def testAutoConnectCreateSSHTunnelFail(self):
        """Test AutoConnect returns no forwarded ports if the tunnel fails."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        call_side_effect = subprocess.CalledProcessError(123, "fake",
                                                         "fake error")
        result = utils.ForwardedPorts(vnc_port=None, adb_port=None)
        self.Patch(subprocess, "check_call", side_effect=call_side_effect)
        self.assertEqual(result, utils.AutoConnect(fake_ip_addr,
                                                   fake_rsa_key_file,
                                                   fake_target_vnc_port,
                                                   target_adb_port,
                                                   ssh_user))

    # pylint: disable=protected-access,no-member
    def testExtraArgsSSHTunnel(self):
        """Test extra args will be the same with expanded args."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        fake_target_vnc_port = 8888
        target_adb_port = 9999
        ssh_user = "fake_user"
        fake_port = 12345
        self.Patch(utils, "PickFreePort", return_value=fake_port)
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        extra_args_ssh_tunnel = "-o command='shell %s %h' -o command1='ls -la'"
        utils.AutoConnect(ip_addr=fake_ip_addr,
                          rsa_key_file=fake_rsa_key_file,
                          target_vnc_port=fake_target_vnc_port,
                          target_adb_port=target_adb_port,
                          ssh_user=ssh_user,
                          client_adb_port=fake_port,
                          extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "12345:127.0.0.1:9999",
                     "-L", "12345:127.0.0.1:8888",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                     "-o", "command=shell %s %h",
                     "-o", "command1=ls -la"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

    # pylint: disable=protected-access,no-member
    def testEstablishWebRTCSshTunnel(self):
        """Test establish WebRTC ssh tunnel."""
        fake_ip_addr = "1.1.1.1"
        fake_rsa_key_file = "/tmp/rsa_file"
        ssh_user = "fake_user"
        self.Patch(utils, "ReleasePort")
        self.Patch(utils, "_ExecuteCommand")
        self.Patch(subprocess, "check_call", return_value=True)
        # First call: no extra ssh tunnel args.
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=None)
        args_list = ["-i", "/tmp/rsa_file",
                     "-o", "UserKnownHostsFile=/dev/null",
                     "-o", "StrictHostKeyChecking=no",
                     "-L", "8443:127.0.0.1:8443",
                     "-L", "15550:127.0.0.1:15550",
                     "-L", "15551:127.0.0.1:15551",
                     "-N", "-f", "-l", "fake_user", "1.1.1.1"]
        first_call_args = utils._ExecuteCommand.call_args_list[0][0]
        self.assertEqual(first_call_args[1], args_list)

        # Second call: extra args are appended after the base ssh args.
        extra_args_ssh_tunnel = "-o command='shell %s %h'"
        utils.EstablishWebRTCSshTunnel(
            ip_addr=fake_ip_addr, rsa_key_file=fake_rsa_key_file,
            ssh_user=ssh_user, extra_args_ssh_tunnel=extra_args_ssh_tunnel)
        args_list_with_extra_args = ["-i", "/tmp/rsa_file",
                                     "-o", "UserKnownHostsFile=/dev/null",
                                     "-o", "StrictHostKeyChecking=no",
                                     "-L", "8443:127.0.0.1:8443",
                                     "-L", "15550:127.0.0.1:15550",
                                     "-L", "15551:127.0.0.1:15551",
                                     "-N", "-f", "-l", "fake_user", "1.1.1.1",
                                     "-o", "command=shell %s %h"]
        first_call_args = utils._ExecuteCommand.call_args_list[1][0]
        self.assertEqual(first_call_args[1], args_list_with_extra_args)

    # pylint: disable=protected-access, no-member
    def testCleanupSSVncviwer(self):
        """Test cleanup ssvnc viewer."""
        fake_vnc_port = 9999
        fake_ss_vncviewer_pattern = utils._SSVNC_VIEWER_PATTERN % {
            "vnc_port": fake_vnc_port}
        self.Patch(utils, "IsCommandRunning", return_value=True)
        self.Patch(subprocess, "check_call", return_value=True)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_called_with(
            ["pkill", "-9", "-f", fake_ss_vncviewer_pattern])

        # reset_mock() clears call_count AND the recorded call args, unlike
        # assigning call_count = 0 directly.
        subprocess.check_call.reset_mock()
        self.Patch(utils, "IsCommandRunning", return_value=False)
        utils.CleanupSSVncviewer(fake_vnc_port)
        subprocess.check_call.assert_not_called()

    def testLaunchBrowserFromReport(self):
        """Test launch browser from report."""
        self.Patch(webbrowser, "open_new_tab")
        fake_report = mock.MagicMock(data={})

        # test remote instance
        self.Patch(os.environ, "get", return_value=True)
        fake_report.data = {
            "devices": [{"instance_name": "remote_cf_instance_name",
                         "ip": "192.168.1.1",},],}

        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443")
        webbrowser.open_new_tab.reset_mock()

        # test local instance
        fake_report.data = {
            "devices": [{"instance_name": "local-instance1",
                         "ip": "127.0.0.1:6250",},],}
        utils.LaunchBrowserFromReport(fake_report)
        webbrowser.open_new_tab.assert_called_once_with("https://localhost:8443")
        webbrowser.open_new_tab.reset_mock()

        # verify terminal can't support launch webbrowser.
        self.Patch(os.environ, "get", return_value=False)
        utils.LaunchBrowserFromReport(fake_report)
        self.assertEqual(webbrowser.open_new_tab.call_count, 0)

    def testSetExecutable(self):
        """Test setting a file to be executable."""
        with tempfile.NamedTemporaryFile(delete=True) as temp_file:
            utils.SetExecutable(temp_file.name)
            self.assertEqual(os.stat(temp_file.name).st_mode & 0o777, 0o755)

    def testSetDirectoryTreeExecutable(self):
        """Test setting a file in a directory to be executable."""
        with tempfile.TemporaryDirectory() as temp_dir:
            subdir = os.path.join(temp_dir, "subdir")
            file_path = os.path.join(subdir, "file")
            os.makedirs(subdir)
            with open(file_path, "w"):
                pass
            utils.SetDirectoryTreeExecutable(temp_dir)
            self.assertEqual(os.stat(file_path).st_mode & 0o777, 0o755)


if __name__ == "__main__":
    unittest.main()