• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 - The Android Open Source Project
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15import collections
16import logging
17import os
18import re
19import shutil
20import tempfile
21import threading
22import time
23import uuid
24
25from acts.controllers.utils_lib import host_utils
26from acts.controllers.utils_lib.ssh import formatter
27from acts.libs.proc import job
28
29
class Error(Exception):
    """An error occurred while performing an ssh operation."""

    pass
32
33
class CommandError(Exception):
    """Raised when ssh connected but the remote command had an error.

    Attributes:
        result: The result of the failing ssh command. It must expose
                ``command``, ``stdout`` and ``stderr`` attributes, which are
                used to render the error message.
    """

    def __init__(self, result):
        """Stores the result of the failing command.

        Args:
            result: The result of the ssh command that created the problem.
        """
        self.result = result

    def __str__(self):
        res = self.result
        # One "label: value" line per field, joined by newlines.
        parts = (
            'cmd: %s' % (res.command,),
            'stdout: %s' % (res.stdout,),
            'stderr: %s' % (res.stderr,),
        )
        return '\n'.join(parts)
51
52
# Bookkeeping record for one ssh port-forwarding tunnel: the local port being
# forwarded, the remote port it forwards to, and the process running it.
_Tunnel = collections.namedtuple('_Tunnel', 'local_port remote_port proc')
55
56
class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will attempt to connect normally.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection. They are
                      passed to the formatter and read for hostname and
                      connect_timeout.
        """
        # Set the attributes close() relies on before anything that can
        # raise, so __del__ -> close() still works on a half-built object.
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()
        self._settings = settings
        self._formatter = formatter.SshFormatter()
        self._lock = threading.Lock()

    def __del__(self):
        """Releases the master connection and all tunnels."""
        self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. An existing master whose socket file is missing, or whose
        process has exited, is torn down and replaced.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
                             be made.

        Raises:
            Error: When setting up the master ssh connection fails. This
                   happens when the master process exits before its control
                   socket appears, or when the socket does not appear within
                   timeout_seconds.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                if (not os.path.exists(self.socket_path) or
                        self._master_ssh_proc.poll() is not None):
                    logging.debug('Master ssh connection to %s is down.',
                                  self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is not None:
                # The existing master connection is still healthy.
                return

            # Remove any temp dir left behind by an attempt that failed before
            # its process was recorded, so it is not leaked.
            self._cleanup_master_ssh()

            # Create a shared socket in a temp location.
            self._master_ssh_tempdir = tempfile.mkdtemp(prefix='ssh-master')

            # Setup flags and options for running the master ssh
            # -N: Do not execute a remote command.
            # ControlMaster: Spawn a master connection.
            # ControlPath: The master connection socket path.
            # BatchMode: Disable interactive prompts (see ssh_config(5)).
            extra_flags = {'-N': None}
            extra_options = {
                'ControlMaster': True,
                'ControlPath': self.socket_path,
                'BatchMode': True
            }

            # Construct the command and start it.
            master_cmd = self._formatter.format_ssh_local_command(
                self._settings,
                extra_flags=extra_flags,
                extra_options=extra_options)
            logging.info('Starting master ssh connection to %s',
                         self._settings.hostname)
            self._master_ssh_proc = job.run_async(master_cmd)

            end_time = time.time() + timeout_seconds
            while time.time() < end_time:
                if os.path.exists(self.socket_path):
                    return
                if self._master_ssh_proc.poll() is not None:
                    # The master process is gone, so its socket will never
                    # appear; waiting out the timeout cannot help.
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection exited before its '
                                'control socket was created.')
                time.sleep(.2)

            self._cleanup_master_ssh()
            raise Error('Master ssh connection timed out.')

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8'):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. Can be either a string
                     or a list. List elements are shell-quoted, so each one
                     reaches the remote side as a single argument.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess.  Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
                         NOTE(review): currently unused; the encoding
                         reported by job.run is used instead.

        Returns:
            A job.Result containing the results of the ssh command. Its
            stdout excludes the internal "CONNECTED" marker line.

        Raises:
            job.TimeoutError: When the remote command took too long to execute.
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                       status and ignore_status is False.
        """
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            logging.warning('Failed to create master ssh connection, using '
                            'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker is echoed before the command. Seeing it in stdout
        # proves ssh connected and started the command, which separates
        # command failures from ssh failures.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (
            identifier, self._command_to_str(command))

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(terminal_command,
                             ignore_status=True,
                             timeout=timeout)

            real_output = self._strip_connect_line(result.stdout, identifier)
            if real_output is not None:
                result = job.Result(
                    command=result.command,
                    stdout=real_output.encode(result._encoding),
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=result._encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                logging.debug('Failed to connecto to host, retrying...')
            else:
                break

        # The trailing carriage return is optional, since ssh error lines
        # may or may not end in \r.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        raise Error('The job failed for unkown reasons.', result)

    @staticmethod
    def _command_to_str(command):
        """Converts a string or list command into a single shell string.

        Args:
            command: A string, or a list/tuple of arguments.

        Returns:
            For a list/tuple, the shell-quoted elements joined by spaces, so
            each element is one argument on the remote side. Anything else is
            returned as str(command).
        """
        if isinstance(command, (list, tuple)):
            import shlex
            return ' '.join(shlex.quote(str(arg)) for arg in command)
        return str(command)

    @staticmethod
    def _strip_connect_line(output, identifier):
        """Returns the command output that follows the CONNECTED marker line.

        Args:
            output: str stdout captured from the ssh invocation.
            identifier: The unique id echoed as 'CONNECTED: <identifier>'.

        Returns:
            The text after the marker line ('' when nothing follows it), or
            None when the marker is absent, i.e. the command never started.
        """
        marker = re.search('^CONNECTED: %s' % re.escape(identifier), output,
                           flags=re.MULTILINE)
        if marker is None:
            return None
        newline = output.find('\n', marker.end())
        if newline == -1:
            # The marker is the last line: the command printed nothing.
            return ''
        return output[newline + 1:]

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.

        Args:
            command: The command to execute over ssh. Can be either a string
                     or a list.
            env: A dictionary of environment variables to setup on the remote
                 host.

        Returns:
            The job.Result of the command that launched the background job.
            Its stdout holds the remote PID (printed by `echo -n $!`).

        Raises:
            job.TimeoutError: When the launch command took too long to execute.
            Error: When the ssh connection failed to be created.
            job.Error: When the launch command exited with a non-zero status.
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % (
            self._command_to_str(command),)
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host.

        Stops the master ssh process, removes its temp dir and closes every
        tunnel created by create_ssh_tunnel.
        """
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            logging.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket.
        if self._master_ssh_tempdir is not None:
            logging.debug('Cleaning master_ssh_tempdir.')
            try:
                shutil.rmtree(self._master_ssh_tempdir)
            except OSError as e:
                # Best effort: the directory may already be gone (which is
                # one reason a master is considered down). Raising here would
                # abort reconnecting or closing.
                logging.warning('Failed to remove master ssh temp dir %s: %s',
                                self._master_ssh_tempdir, e)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port. If a tunnel between exactly this local_port
        and port already exists, it is reused.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                        port.

        Returns:
            int, the local port that forwards to the remote port.
        """
        if local_port is None:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                # Only reuse a tunnel on the requested local port; returning
                # a different local port would ignore the caller's request.
                if (tunnel.remote_port == port and
                        tunnel.local_port == local_port):
                    return local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic messages
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        logging.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        logging.debug('Started ssh tunnel, local = %d'
                      ' remote = %d, pid = %d', local_port, port,
                      tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                        host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path):
        """Send a file from the local host to the remote host.

        NOTE(review): scp receives only the user/host from the formatter, so
        other connection settings (port, identity, options) are presumably
        not applied, and paths are not quoted. Confirm before relying on it.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
        """
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s %s:%s' % (local_path, user_host, remote_path))

    def find_free_port(self, interface_name='localhost'):
        """Find a unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free. It requires a
        `python` executable on the remote host.

        Args:
            interface_name: string name of interface to check whether a
                            port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        # Binding to port 0 makes the remote kernel pick a free port.
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port
404