#!/usr/bin/env python
#
# Copyright 2019 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for acloud.internal.lib.ssh."""

import subprocess
import unittest
import mock

from acloud import errors
from acloud.internal import constants
from acloud.internal.lib import driver_test_lib
from acloud.internal.lib import ssh


class SshTest(driver_test_lib.BaseDriverTest):
    """Tests for the ssh module: command building, IP selection and retries."""

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Set up the test.

        Builds a fake process object to hand back from a patched
        subprocess.Popen: stdout.readline() returns '', poll() and returncode
        are 0, and communicate() returns ('', '').
        """
        super(SshTest, self).setUp()
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout = mock.MagicMock()
        self.created_subprocess.stdout.readline = mock.MagicMock(return_value='')
        self.created_subprocess.poll = mock.MagicMock(return_value=0)
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate = mock.MagicMock(return_value=('', ''))

    # pylint: disable=no-member
    @staticmethod
    def _AssertPopenCalledWith(expected_cmd):
        """Assert the patched subprocess.Popen was last called with expected_cmd.

        The assertion also pins the Popen keyword arguments: shell=True,
        stderr=subprocess.STDOUT, stdin=None and stdout=subprocess.PIPE.

        Args:
            expected_cmd: String, the full shell command Popen should receive.
        """
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE)

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry raises CalledProcessError when Popen fails."""
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test get base command with internal ip."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test get base command for both the ssh and scp binaries."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        # scp's base command carries no "-l user host" part.
        expected_scp_cmd = ("/usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    def testSshRunCmd(self):
        """Test ssh run command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test ssh run command with extra ssh args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test scp pull file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test scp pull file command with extra ssh args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test scp push file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no /tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test scp push file command with extra ssh args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test which IP address Ssh selects from an ssh.IP."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        expected_ip = "10.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

    def testWaitForSsh(self):
        """Test WaitForSsh raises DeviceConnectionError when _SshCall fails."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        # A non-zero return value from _SshCall simulates a failed ssh probe.
        self.Patch(ssh, "_SshCall", return_value=-1)
        self.assertRaises(errors.DeviceConnectionError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          sleep_for_retry=1,
                          max_retry=1)


if __name__ == "__main__":
    unittest.main()