# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import os
import re
import shutil
import tempfile
import threading
import time
import uuid

from acts import logger
from acts.controllers.utils_lib import host_utils
from acts.controllers.utils_lib.ssh import formatter
from acts.libs.proc import job


class Error(Exception):
    """An error occurred during an ssh operation."""


class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """

    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
        """
        # Forward to Exception so that `args` is populated (used by repr()
        # and when copying/pickling the exception).
        super(CommandError, self).__init__(result)
        self.result = result

    def __str__(self):
        return 'cmd: %s\nstdout: %s\nstderr: %s' % (
            self.result.command, self.result.stdout, self.result.stderr)


# Bookkeeping for an ssh port-forwarding process started by
# SshConnection.create_ssh_tunnel.
_Tunnel = collections.namedtuple('_Tunnel',
                                 ['local_port', 'remote_port', 'proc'])


class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will attempt to connect normally.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection.
        """
        # close() (called from __del__) reads these attributes, so set them
        # before anything else that could raise; otherwise a failed
        # construction would make __del__ raise AttributeError.
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        self._tunnels = list()

        self._settings = settings
        self._formatter = formatter.SshFormatter()
        self._lock = threading.Lock()

        def log_line(msg):
            return '[SshConnection | %s] %s' % (self._settings.hostname, msg)

        self.log = logger.create_logger(log_line)

    def __enter__(self):
        return self

    def __exit__(self, _, __, ___):
        self.close()

    def __del__(self):
        self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. If a previous master connection is found to be down (its
        socket file is missing or its process has exited) it is cleaned up
        and a new one is started.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
                be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path)
                        or self._master_ssh_proc.poll() is not None):
                    self.log.debug('Master ssh connection to %s is down.',
                                   self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                self.log.info('Starting master ssh connection.')
                self._master_ssh_proc = job.run_async(master_cmd)

                end_time = time.time() + timeout_seconds

                # The master connection is considered up once its control
                # socket file appears.
                while time.time() < end_time:
                    if os.path.exists(self.socket_path):
                        break
                    time.sleep(.2)
                else:
                    self._cleanup_master_ssh()
                    raise Error('Master ssh connection timed out.')

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8',
            attempts=2):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. Can be either a string
                or a list.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                subprocess. Note that if you do ignore status codes,
                you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
            attempts: Number of attempts before giving up on command failures.

        Returns:
            A job.Result containing the results of the ssh command, or None
            when attempts is 0.

        Raises:
            job.TimeoutError: When the remote command took too long to
                execute.
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                status and ignore_status is False.
        """
        if attempts == 0:
            return None
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            self.log.warning('Failed to create master ssh connection, using '
                             'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # Echo a unique marker before the real command so that output from a
        # successful connection can be told apart from an ssh-level failure.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(
                terminal_command,
                ignore_status=True,
                timeout=timeout,
                io_encoding=io_encoding)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search(
                '^CONNECTED: %s' % identifier, output, flags=re.MULTILINE)
            if valid_connection:
                # Remove everything through the end of the line holding the
                # connect marker. Anything before the marker was not produced
                # by the command (it is echoed before the command runs), and
                # the marker is not guaranteed to be on the first line.
                newline_index = output.find('\n', valid_connection.end())
                if newline_index == -1:
                    line_index = len(output)
                else:
                    line_index = newline_index + 1
                real_output = output[line_index:].encode(io_encoding)

                result = job.Result(
                    command=result.command,
                    stdout=real_output,
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=io_encoding)
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                self.log.debug('Failed to connect to host, retrying...')
            else:
                break

        # The trailing carriage return is optional: it is present only when
        # the error text went through a terminal.
        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r?$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        self.log.error('An unknown error has occurred. Job result: %s', result)
        ping_output = job.run(
            'ping %s -c 3 -w 1' % self._settings.hostname, ignore_status=True)
        self.log.error('Ping result: %s', ping_output)
        if attempts > 1:
            # Start over with a fresh master connection and hand the retry's
            # outcome (result or exception) back to the caller.
            self._cleanup_master_ssh()
            return self.run(command, timeout, ignore_status, env, io_encoding,
                            attempts - 1)
        raise Error('The job failed for unknown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.

        Args:
            command: The command to execute over ssh. Can be either a string
                or a list.
            env: A dictionary of environment variables to setup on the remote
                host.

        Returns:
            The result of the command to launch the background job. Its
            stdout is the remote shell's `$!`, i.e. the PID of the
            backgrounded command.

        Raises:
            job.TimeoutError: When the launching command took too long to
                execute.
            Error: When the ssh connection failed to be created.
            job.Error: When the launching command exited with a non-zero
                status.
        """
        # Detach the command from ssh's stdio so the connection can return
        # immediately, then print the background PID.
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            self.log.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket.
        if self._master_ssh_tempdir is not None:
            self.log.debug('Cleaning master_ssh_tempdir.')
            shutil.rmtree(self._master_ssh_tempdir)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                port. When given and a tunnel to the same remote port already
                exists, no new tunnel is created and the existing tunnel's
                local port is returned (which may differ from local_port).

        Returns:
            The local port number that forwards to the remote port.
        """
        if not local_port:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        self.log.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        self.log.debug('Started ssh tunnel, local = %d remote = %d, pid = %d',
                       local_port, port, tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path, ignore_status=False):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        # TODO: This may belong somewhere else: b/32572515
        user_host = self._formatter.format_host_name(self._settings)
        job.run(
            'scp %s %s:%s' % (local_path, user_host, remote_path),
            ignore_status=ignore_status)

    def pull_file(self, local_path, remote_path, ignore_status=False):
        """Copy a file from the remote host to the local host.

        Args:
            local_path: string path of file to recv on local host
            remote_path: string path to copy file from on remote host.
            ignore_status: Whether or not to ignore the command's exit_status.
        """
        user_host = self._formatter.format_host_name(self._settings)
        job.run(
            'scp %s:%s %s' % (user_host, remote_path, local_path),
            ignore_status=ignore_status)

    def find_free_port(self, interface_name='localhost'):
        """Find a unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        # Bind to port 0 on the remote side so the remote OS picks a free
        # port, print it, and release it.
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port