• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2019 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Tests for acloud.internal.lib.ssh."""
18
19import subprocess
20import unittest
21import mock
22
23from acloud import errors
24from acloud.internal import constants
25from acloud.internal.lib import driver_test_lib
26from acloud.internal.lib import ssh
27
28
class SshTest(driver_test_lib.BaseDriverTest):
    """Tests for ssh.Ssh command construction and ssh.ShellCmdWithRetry.

    Most tests patch subprocess.Popen with a mock, drive one public method of
    ssh.Ssh, and then check the exact shell command string handed to Popen.
    """

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Create a fake subprocess that exits with code 0 and no output."""
        super(SshTest, self).setUp()
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout = mock.MagicMock()
        self.created_subprocess.stdout.readline = mock.MagicMock(return_value='')
        self.created_subprocess.poll = mock.MagicMock(return_value=0)
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate = mock.MagicMock(
            return_value=('', ''))

    @staticmethod
    def _AssertPopenCalledWith(expected_cmd):
        """Assert the latest subprocess.Popen call ran |expected_cmd|.

        The caller must already have patched subprocess.Popen with a mock.

        Args:
            expected_cmd: String, the full shell command expected as the
                          first positional argument of the Popen call.
        """
        # subprocess.STDOUT and subprocess.PIPE are the named forms of the
        # -2 / -1 sentinels; only Popen is patched, so they stay real values.
        # pylint: disable=no-member
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE)

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry raises when Popen raises CalledProcessError."""
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test get base command with internal ip."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test get base command for both the ssh and the scp binary."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = ("/usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        # The expected scp base command carries no "-l user host" suffix.
        expected_scp_cmd = ("/usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                            "-o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    def testSshRunCmd(self):
        """Test ssh run command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no -l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test ssh run command with extra args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.Run("command")
        expected_cmd = ("exec /usr/bin/ssh -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test scp pull file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test scp pull file command with extra args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test scp push file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no /tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test scp push file command with extra args."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = ("exec /usr/bin/scp -i /fake/acloud_rea -q -o UserKnownHostsFile=/dev/null "
                        "-o StrictHostKeyChecking=no "
                        "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
                        "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test which address ssh.Ssh stores in _ip for each ssh.IP form."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        expected_ip = "10.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

    def testWaitForSsh(self):
        """Test WaitForSsh raises DeviceConnectionError when _SshCall returns -1."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        self.Patch(ssh, "_SshCall", return_value=-1)
        self.assertRaises(errors.DeviceConnectionError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          sleep_for_retry=1,
                          max_retry=1)
210
211
# Run every test case in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
214