#!/usr/bin/env python
#
# Copyright 2019 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for acloud.internal.lib.ssh."""

import io
import subprocess
import threading
import time
import unittest

from unittest import mock

from acloud import errors
from acloud.internal import constants
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import ssh
from acloud.internal.lib import utils


class SshTest(driver_test_lib.BaseDriverTest):
    """Unit tests for the helpers and the Ssh class in acloud.internal.lib.ssh."""

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Set up the test.

        Makes utils.FindExecutable resolve every binary to /usr/bin/<name> and
        prepares a fake subprocess that exits with 0 and produces no output.
        """
        super().setUp()
        self.Patch(utils, "FindExecutable",
                   side_effect=lambda name: f"/usr/bin/{name}")
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout = mock.MagicMock()
        self.created_subprocess.stdout.readline = mock.MagicMock(return_value=b"")
        self.created_subprocess.poll = mock.MagicMock(return_value=0)
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate = mock.MagicMock(
            return_value=("", ""))

    @staticmethod
    def _AssertPopenCalledWith(expected_cmd):
        """Assert the patched subprocess.Popen was last called with expected_cmd.

        Every ssh/scp command in these tests is expected to run through a
        shell, with stderr merged into a piped stdout in text mode.

        Args:
            expected_cmd: String, the full shell command line expected.
        """
        # pylint: disable=no-member
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE,
                                            universal_newlines=True)

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry re-raises CalledProcessError from Popen."""
        # time.sleep is patched, presumably to skip any delay between retries.
        self.Patch(time, "sleep")
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test the base ssh command targets the internal ip when requested."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = (
            "/usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test the base ssh and scp commands built for the external ip."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = (
            "/usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        # The scp base command carries no user/host; those go on the paths.
        expected_scp_cmd = (
            "/usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    # pylint: disable=no-member
    def testSshRunCmd(self):
        """Test Ssh.Run returns stdout and launches the expected ssh command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.created_subprocess.communicate.return_value = ("stdout", "")
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        self.assertEqual("stdout", ssh_object.Run("command"))
        expected_cmd = (
            "exec /usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test Ssh.Run places extra ssh args before the user/host arguments."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.created_subprocess.communicate.return_value = ("stdout", "")
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        self.assertEqual("stdout", ssh_object.Run("command"))
        expected_cmd = (
            "exec /usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
            "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test the scp command used to pull a remote file."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test the scp pull command when extra ssh args are supplied."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
            "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test the scp command used to push a local file."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test the scp push command when extra ssh args are supplied."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
            "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test which address Ssh selects from an IP object."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        expected_ip = "10.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

    def testWaitForSsh(self):
        """Test WaitForSsh raises CalledProcessError when ssh keeps failing."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        self.created_subprocess.returncode = -1
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.assertRaises(subprocess.CalledProcessError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          max_retry=1)

    def testSshCallWait(self):
        """Test _SshCallWait starts no Timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCallWait(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallWaitTimeout(self):
        """Test _SshCallWait starts exactly one Timer when a timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCallWait(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshCall(self):
        """Test _SshCall starts no Timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCall(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallTimeout(self):
        """Test _SshCall starts exactly one Timer when a timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCall(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshLogOutput(self):
        """Test _SshLogOutput success and how failures map to exceptions."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshLogOutput(fake_cmd)
        threading.Timer.assert_not_called()

        # Test with all kinds of exceptions.
        self.created_subprocess.returncode = 255
        self.assertRaises(
            errors.DeviceConnectionError, ssh._SshLogOutput, fake_cmd)

        self.created_subprocess.returncode = -1
        self.assertRaises(
            subprocess.CalledProcessError, ssh._SshLogOutput, fake_cmd)

        # sys.stderr is swapped out to keep the test output quiet.
        with mock.patch("sys.stderr", new=io.StringIO()):
            self.created_subprocess.communicate = mock.MagicMock(
                return_value=(constants.ERROR_MSG_VNC_NOT_SUPPORT, ''))
            self.assertRaises(
                errors.LaunchCVDFail, ssh._SshLogOutput, fake_cmd)

        with mock.patch("sys.stderr", new=io.StringIO()):
            self.created_subprocess.communicate = mock.MagicMock(
                return_value=(constants.ERROR_MSG_WEBRTC_NOT_SUPPORT, ''))
            self.assertRaises(
                errors.LaunchCVDFail, ssh._SshLogOutput, fake_cmd)

    def testSshLogOutputTimeout(self):
        """Test _SshLogOutput starts exactly one Timer when a timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshLogOutput(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testGetErrorMessage(self):
        """Test _GetErrorMessage extracts "response", then "message", else tail."""
        # should return response
        fake_output = """
fetch_cvd E 10-25 09:45:44 1337 1337 build_api.cc:184] URL endpoint did not have json path: {
fetch_cvd E 10-25 09:45:44 1337 1337 build_api.cc:184] "error" : "Failed to parse json.",
fetch_cvd E 10-25 09:45:44 1337 1337 build_api.cc:184] "response" : "fake_error_response"
fetch_cvd E 10-25 09:45:44 1337 1337 build_api.cc:184] }
fetch_cvd E 10-25 09:45:44 1337 1337 fetch_cvd.cc:102] Unable to download."""
        self.assertEqual(ssh._GetErrorMessage(fake_output), "fake_error_response")

        # should return message only
        fake_output = """
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] Error fetching the artifacts
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "error" :
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] {
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "code" : 500,
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "errors" :
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] [
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] {}
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] ],
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "message" : "Unknown Error.",
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "status" : "UNKNOWN"
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] }
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] }", and code was 500Fail! (320s)"""
        self.assertEqual(ssh._GetErrorMessage(fake_output), "Unknown Error.")

        # should output the last 10 lines
        fake_output = """
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] Error fetching the artifacts of {
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "error" :
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] {
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "code" : 500,
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "errors" :
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] [
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] {}
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] ],
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] "status" : "UNKNOWN"
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] }
fetch_cvd F 11-15 07:34:13 2169 2169 build_api.cc:164] }", and code was 500Fail! (320s)"""
        self.assertEqual(ssh._GetErrorMessage(fake_output), "\n".join(
            fake_output.splitlines()[-10:]))

    def testFilterUnusedContent(self):
        """Test _FilterUnusedContent."""
        # should remove html, !, title, span, a, p, b, style, ins, code, \n
        fake_content = ("<!DOCTYPE html><html lang=en>\\n<meta charset=utf-8>"
                        "<title>Error</title>\\n<style>*{padding:0}html}</style>"
                        "<a href=//www.google.com/><span id=logo></span></a>"
                        "<p><b>404.</b> <ins>That\u2019s an error.</ins><p>"
                        "The requested URL was not found on this server <code>"
                        "url/id</code> <ins>That\u2019s all we know.</ins>\\n")
        expected = (" Error 404. That’s an error.The requested URL was not"
                    " found on this server url/id That’s all we know. ")
        self.assertEqual(ssh._FilterUnusedContent(fake_content), expected)


if __name__ == "__main__":
    unittest.main()