• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2#
3# Copyright 2019 - The Android Open Source Project
4#
5# Licensed under the Apache License, Version 2.0 (the "License");
6# you may not use this file except in compliance with the License.
7# You may obtain a copy of the License at
8#
9#     http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing, software
12# distributed under the License is distributed on an "AS IS" BASIS,
13# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14# See the License for the specific language governing permissions and
15# limitations under the License.
16
17"""Tests for acloud.internal.lib.ssh."""
18
19import io
20import subprocess
21import threading
22import time
23import unittest
24
25from unittest import mock
26
27from acloud import errors
28from acloud.internal import constants
29from acloud.internal.lib import driver_test_lib
30from acloud.internal.lib import ssh
31from acloud.internal.lib import utils
32
33
class SshTest(driver_test_lib.BaseDriverTest):
    """Test ssh class.

    Tests that exercise command execution patch subprocess.Popen so that it
    returns self.created_subprocess; they check the command line handed to
    Popen and how the ssh module reacts to the fake process' return code
    and output.
    """

    FAKE_SSH_PRIVATE_KEY_PATH = "/fake/acloud_rea"
    FAKE_SSH_USER = "fake_user"
    FAKE_IP = ssh.IP(external="1.1.1.1", internal="10.1.1.1")
    FAKE_EXTRA_ARGS_SSH = "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22'"
    FAKE_REPORT_INTERNAL_IP = True

    def setUp(self):
        """Set up the test.

        Patches utils.FindExecutable so that a binary name resolves to
        /usr/bin/<name>, and builds the fake process object that tests hand
        out from the patched subprocess.Popen.
        """
        super().setUp()
        self.Patch(utils, "FindExecutable",
                   side_effect=lambda name: f"/usr/bin/{name}")
        # By default the fake process produces no output and exits with 0;
        # individual tests override returncode / communicate as needed.
        self.created_subprocess = mock.MagicMock()
        self.created_subprocess.stdout.readline.return_value = b""
        self.created_subprocess.poll.return_value = 0
        self.created_subprocess.returncode = 0
        self.created_subprocess.communicate.return_value = ('', '')

    # pylint: disable=no-member
    def _AssertPopenCalledWith(self, expected_cmd):
        """Assert the latest subprocess.Popen call ran the expected command.

        subprocess.Popen must already be patched with a mock by the caller.

        Args:
            expected_cmd: String, the full shell command expected as the
                          first positional argument of Popen.
        """
        # Named constants instead of the raw -2 / -1 values, so the
        # expectation reads as "stderr merged into a piped stdout".
        subprocess.Popen.assert_called_with(expected_cmd,
                                            shell=True,
                                            stderr=subprocess.STDOUT,
                                            stdin=None,
                                            stdout=subprocess.PIPE,
                                            universal_newlines=True)

    def testSSHExecuteWithRetry(self):
        """Test ShellCmdWithRetry raises when Popen keeps failing."""
        # Sleep is mocked out, presumably so that retries add no real delay.
        self.Patch(time, "sleep")
        self.Patch(subprocess, "Popen",
                   side_effect=subprocess.CalledProcessError(
                       None, "ssh command fail."))
        self.assertRaises(subprocess.CalledProcessError,
                          ssh.ShellCmdWithRetry,
                          "fake cmd")

    def testGetBaseCmdWithInternalIP(self):
        """Test get base command with internal ip."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        expected_ssh_cmd = (
            "/usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-l fake_user 10.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

    def testGetBaseCmd(self):
        """Test get base command for both the ssh and the scp binary."""
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ssh_cmd = (
            "/usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-l fake_user 1.1.1.1")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SSH_BIN), expected_ssh_cmd)

        # The scp base command carries the same options but no "-l user ip".
        expected_scp_cmd = (
            "/usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no")
        self.assertEqual(ssh_object.GetBaseCmd(constants.SCP_BIN), expected_scp_cmd)

    def testSshRunCmd(self):
        """Test ssh run command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.created_subprocess.communicate.return_value = ("stdout", "")
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        self.assertEqual("stdout", ssh_object.Run("command"))
        expected_cmd = (
            "exec /usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testSshRunCmdwithExtraArgs(self):
        """Test ssh run command with extra ssh arguments."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.created_subprocess.communicate.return_value = ("stdout", "")
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        self.assertEqual("stdout", ssh_object.Run("command"))
        expected_cmd = (
            "exec /usr/bin/ssh -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
            "-l fake_user 1.1.1.1 command")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmd(self):
        """Test scp pull file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPullFileCmdwithExtraArgs(self):
        """Test scp pull file command with extra ssh arguments."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPullFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
            "fake_user@1.1.1.1:/tmp/test /tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmd(self):
        """Test scp push file command."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP, self.FAKE_SSH_USER, self.FAKE_SSH_PRIVATE_KEY_PATH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    def testScpPushFileCmdwithExtraArgs(self):
        """Test scp push file command with extra ssh arguments."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        ssh_object = ssh.Ssh(self.FAKE_IP,
                             self.FAKE_SSH_USER,
                             self.FAKE_SSH_PRIVATE_KEY_PATH,
                             self.FAKE_EXTRA_ARGS_SSH)
        ssh_object.ScpPushFile("/tmp/test", "/tmp/test_1.log")
        expected_cmd = (
            "exec /usr/bin/scp -i /fake/acloud_rea -o LogLevel=ERROR -o ControlPath=none "
            "-o UserKnownHostsFile=/dev/null -o StrictHostKeyChecking=no "
            "-o ProxyCommand='ssh fake_user@2.2.2.2 Server 22' "
            "/tmp/test fake_user@1.1.1.1:/tmp/test_1.log")
        self._AssertPopenCalledWith(expected_cmd)

    # pylint: disable=protected-access
    def testIPAddress(self):
        """Test IP class to get ip address."""
        # Internal ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=True)
        expected_ip = "10.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # External ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(external="1.1.1.1", internal="10.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

        # Only one ip case.
        ssh_object = ssh.Ssh(ip=ssh.IP(ip="1.1.1.1"),
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH)
        expected_ip = "1.1.1.1"
        self.assertEqual(ssh_object._ip, expected_ip)

    def testWaitForSsh(self):
        """Test WaitForSsh raises when the ssh process returns -1."""
        ssh_object = ssh.Ssh(ip=self.FAKE_IP,
                             user=self.FAKE_SSH_USER,
                             ssh_private_key_path=self.FAKE_SSH_PRIVATE_KEY_PATH,
                             report_internal_ip=self.FAKE_REPORT_INTERNAL_IP)
        self.created_subprocess.returncode = -1
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.assertRaises(subprocess.CalledProcessError,
                          ssh_object.WaitForSsh,
                          timeout=1,
                          max_retry=1)

    def testSshCallWait(self):
        """Test _SshCallWait creates no timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCallWait(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallWaitTimeout(self):
        """Test _SshCallWait with timeout."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCallWait(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshCall(self):
        """Test _SshCall creates no timer when no timeout is given."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshCall(fake_cmd)
        threading.Timer.assert_not_called()

    def testSshCallTimeout(self):
        """Test _SshCall with timeout."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshCall(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testSshLogOutput(self):
        """Test _SshLogOutput without timeout and its error handling."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        ssh._SshLogOutput(fake_cmd)
        threading.Timer.assert_not_called()

        # Test with all kinds of exceptions.
        self.created_subprocess.returncode = 255
        self.assertRaises(
            errors.DeviceConnectionError, ssh._SshLogOutput, fake_cmd)

        self.created_subprocess.returncode = -1
        self.assertRaises(
            subprocess.CalledProcessError, ssh._SshLogOutput, fake_cmd)

        # returncode stays -1 for the cases below; only the process output
        # changes. sys.stderr is swapped for a buffer while the call runs,
        # presumably to keep the error text out of the test output.
        for error_msg in (constants.ERROR_MSG_VNC_NOT_SUPPORT,
                          constants.ERROR_MSG_WEBRTC_NOT_SUPPORT):
            with self.subTest(error_msg=error_msg):
                with mock.patch("sys.stderr", new=io.StringIO()):
                    self.created_subprocess.communicate.return_value = (
                        error_msg, '')
                    self.assertRaises(
                        errors.LaunchCVDFail, ssh._SshLogOutput, fake_cmd)

    def testSshLogOutputTimeout(self):
        """Test _SshLogOutput with timeout."""
        self.Patch(subprocess, "Popen", return_value=self.created_subprocess)
        self.Patch(threading, "Timer")
        fake_cmd = "fake command"
        fake_timeout = 30
        ssh._SshLogOutput(fake_cmd, fake_timeout)
        threading.Timer.assert_called_once()

    def testGetErrorMessage(self):
        """Test _GetErrorMessage."""
        # should return response
        fake_output = """
fetch_cvd E 10-25 09:45:44  1337  1337 build_api.cc:184] URL endpoint did not have json path: {
fetch_cvd E 10-25 09:45:44  1337  1337 build_api.cc:184] 	"error" : "Failed to parse json.",
fetch_cvd E 10-25 09:45:44  1337  1337 build_api.cc:184] 	"response" : "fake_error_response"
fetch_cvd E 10-25 09:45:44  1337  1337 build_api.cc:184] }
fetch_cvd E 10-25 09:45:44  1337  1337 fetch_cvd.cc:102] Unable to download."""
        self.assertEqual(ssh._GetErrorMessage(fake_output), "fake_error_response")

        # should return message only
        fake_output = """
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] Error fetching the artifacts
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 	"error" :
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 	{
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"code" : 500,
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"errors" :
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		[
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 			{}
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		],
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"message" : "Unknown Error.",
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"status" : "UNKNOWN"
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 	}
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] }", and code was 500Fail! (320s)"""
        self.assertEqual(ssh._GetErrorMessage(fake_output), "Unknown Error.")

        # should output last 10 lines (no "response" or "message" field here)
        fake_output = """
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] Error fetching the artifacts of {
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 	"error" :
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 	{
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"code" : 500,
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"errors" :
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		[
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 			{}
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		],
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 		"status" : "UNKNOWN"
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] 	}
fetch_cvd F 11-15 07:34:13  2169  2169 build_api.cc:164] }", and code was 500Fail! (320s)"""
        self.assertEqual(ssh._GetErrorMessage(fake_output), "\n".join(
            fake_output.splitlines()[-10:]))

    def testFilterUnusedContent(self):
        """Test _FilterUnusedContent."""
        # should remove html, !, title, span, a, p, b, style, ins, code, \n
        fake_content = ("<!DOCTYPE html><html lang=en>\\n<meta charset=utf-8>"
                        "<title>Error</title>\\n<style>*{padding:0}html}</style>"
                        "<a href=//www.google.com/><span id=logo></span></a>"
                        "<p><b>404.</b> <ins>That\u2019s an error.</ins><p>"
                        "The requested URL was not found on this server <code>"
                        "url/id</code> <ins>That\u2019s all we know.</ins>\\n")
        expected = (" Error 404. That’s an error.The requested URL was not"
                    " found on this server url/id That’s all we know. ")
        self.assertEqual(ssh._FilterUnusedContent(fake_content), expected)
368
369
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
372