• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15import collections
16import os
17import re
18import shutil
19import tempfile
20import threading
21import time
22import uuid
24from acts import logger
25from acts.controllers.utils_lib import host_utils
26from acts.controllers.utils_lib.ssh import formatter
27from acts.libs.proc import job
class Error(Exception):
    """Raised when an ssh operation fails.

    Within this module it is used for failures of the ssh transport itself,
    such as a master connection that times out, DNS failures, connection
    timeouts and permission problems.
    """
class CommandError(Exception):
    """Signals that a command run over ssh had an error.

    Attributes:
        result: The result object of the ssh command that had the error. It
                is expected to expose `command`, `stdout` and `stderr`.
    """

    def __init__(self, result):
        """Keeps the failing command's result on the exception.

        Args:
            result: The result of the ssh command that created the problem.
        """
        self.result = result

    def __str__(self):
        """Renders the command together with its captured stdout/stderr."""
        failed = self.result
        details = (failed.command, failed.stdout, failed.stderr)
        return 'cmd: %s\nstdout: %s\nstderr: %s' % details
# Bookkeeping record for one forwarded port: the local port being listened
# on, the remote port it forwards to, and the ssh process doing the
# forwarding.
_Tunnel = collections.namedtuple(
    '_Tunnel',
    ['local_port', 'remote_port', 'proc'],
)
class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will attempt to connect normally.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection.
        """
        self._settings = settings
        # Handles formatting ssh commands for use with the background job.
        self._formatter = formatter.SshFormatter()
        # Serializes creation/teardown of the master connection in
        # setup_master_ssh().
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        self.close()

    def __del__(self):
        self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. An existing master whose socket file is missing or whose
        process has exited is torn down and recreated.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
            be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path)
                        or self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                self.log.info('Starting master ssh connection.')
                self._master_ssh_proc = job.run_async(master_cmd)

                end_time = time.time() + timeout_seconds

                # Poll until the control socket shows up; the while/else only
                # reaches the else branch when the deadline passes without a
                # break.
                while time.time() < end_time:
                    if os.path.exists(self.socket_path):
                        break
                    time.sleep(.2)
                else:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')

    @staticmethod
    def _strip_connect_marker(output, identifier):
        """Returns the command output that follows the CONNECTED marker line.

        Args:
            output: str stdout of the ssh invocation.
            identifier: str unique token echoed on the marker line.

        Returns:
            The text after the marker line ('' when nothing follows it), or
            None when the marker is not present in output.
        """
        marker = re.search(
            '^CONNECTED: %s' % re.escape(identifier),
            output,
            flags=re.MULTILINE)
        if marker is None:
            return None
        # Cut at the end of the line that actually holds the marker rather
        # than at the first line of the output: anything printed before the
        # marker does not belong to the command.
        newline_at = output.find('\n', marker.end())
        if newline_at == -1:
            return ''
        return output[newline_at + 1:]

    @staticmethod
    def _diagnose_failure(error_string):
        """Maps known ssh stderr patterns to an error message.

        Args:
            error_string: str stderr of the failed ssh invocation.

        Returns:
            The message for a recognised failure, or None when the stderr
            does not match any known pattern.
        """
        # The trailing carriage return is optional so that both '\r\n' and
        # plain '\n' terminated stderr lines are recognised.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            return 'Ssh timed out.'

        if 'Permission denied' in error_string:
            return 'Permission denied.'

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            return 'Unknown host.'
        return None

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. Can be either a string
                     or a list.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                           subprocess.  Note that if you do ignore status codes,
                           you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
            attempts: Number of attempts before giving up on command failures.
                      With 0 nothing is run and None is returned.

        Returns:
            A job.Result containing the results of the ssh command.

        Raises:
            job.TimeoutError: When the remote command took too long to execute.
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                       status and ignore_status is False.
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker is echoed before the command so that a successful
        # connection can be told apart from an ssh-level failure.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(
                terminal_command,
                ignore_status=True,
                timeout=timeout,
                io_encoding=io_encoding)

            # Check for a connected message to prevent false negatives.
            real_output = self._strip_connect_marker(result.stdout, identifier)
            if real_output is not None:
                result = job.Result(
                    command=result.command,
                    stdout=real_output.encode(io_encoding),
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=io_encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if not had_dns_failure:
                break
            dns_retry_count -= 1
            if not dns_retry_count:
                raise Error('DNS failed to find host.', result)
            self.log.debug('Failed to connect to host, retrying...')

        known_failure = self._diagnose_failure(error_string)
        if known_failure is not None:
            raise Error(known_failure, result)

        self.log.error('An unknown error has occurred. Job result: %s', result)
        ping_output = job.run(
            'ping %s -c 3 -w 1' % self._settings.hostname, ignore_status=True)
        self.log.error('Ping result: %s', ping_output)
        if attempts > 1:
            # Drop the (possibly broken) master connection and try again;
            # the outcome of the retry is the outcome of this call.
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. The command is
        wrapped so that it runs in the background with its stdin, stdout and
        stderr attached to /dev/null, and the wrapper echoes `$!` (the
        background job's pid). This method blocks until that wrapper command
        has returned.

        Args:
            command: The command to execute over ssh. Can be either a string
                     or a list.
            env: A dictionary of environment variables to setup on the remote
                 host.

        Returns:
            The result of the command to launch the background job.

        Raises:
            Whatever run() raises for the wrapper command (see run()).
        """
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        # close_ssh_tunnel() removes the entry from self._tunnels, so the
        # list shrinks on every iteration.
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket.
        if self._master_ssh_tempdir is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            # Best effort: the directory may already have been removed from
            # under us, which must not abort the cleanup or leave a stale
            # path behind.
            shutil.rmtree(self._master_ssh_tempdir, ignore_errors=True)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        When local_port is given and a tunnel to the same remote port is
        already tracked, no new tunnel is started and the local port of that
        existing tunnel is returned instead.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                        port.

        Returns:
            integer local port that forwards to the remote port.
        """
        if not local_port:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d remote = %d, pid = %d',
                       local_port, port, tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                        host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        for index, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                # Forget the tunnel before stopping its process.
                del self._tunnels[index]
                tunnel.proc.kill()
                tunnel.proc.wait()
                return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path, ignore_status=False):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run(
            'scp %s %s:%s' % (local_path, user_host, remote_path),
            ignore_status=ignore_status)

    def pull_file(self, local_path, remote_path, ignore_status=False):
        """Copy a file from the remote host to the local host.

        Args:
            local_path: string path of file to recv on local host
            remote_path: string path to copy file from on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        user_host = self._formatter.format_host_name(self._settings)
        job.run(
            'scp %s:%s %s' % (user_host, remote_path, local_path),
            ignore_status=ignore_status)

    def find_free_port(self, interface_name='localhost'):
        """Find an unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                            port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port