# Copyright 2016 - The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import logging
import os
import re
import shutil
import tempfile
import threading
import time
import uuid

from acts.controllers.utils_lib import host_utils
from acts.controllers.utils_lib.ssh import formatter
from acts.libs.proc import job


class Error(Exception):
    """An error occurred during an ssh operation."""


class CommandError(Exception):
    """An error occurred with the command.

    Attributes:
        result: The results of the ssh command that had the error.
    """

    def __init__(self, result):
        """
        Args:
            result: The result of the ssh command that created the problem.
        """
        # Populate Exception.args so the exception can be pickled/copied
        # (unpickling re-calls the constructor with args).
        super(CommandError, self).__init__(result)
        self.result = result

    def __str__(self):
        return 'cmd: %s\nstdout: %s\nstderr: %s' % (
            self.result.command, self.result.stdout, self.result.stderr)


# An open port forward: local port on this machine, the remote port it
# forwards to, and the ssh process that carries it.
_Tunnel = collections.namedtuple('_Tunnel',
                                 ['local_port', 'remote_port', 'proc'])


class SshConnection(object):
    """Provides a connection to a remote machine through ssh.

    Provides the ability to connect to a remote machine and execute a command
    on it. The connection will try to establish a persistent (master)
    connection when a command is run. If the persistent connection fails it
    will attempt to connect normally.
    """

    @property
    def socket_path(self):
        """Returns: The os path to the master socket file."""
        return os.path.join(self._master_ssh_tempdir, 'socket')

    def __init__(self, settings):
        """
        Args:
            settings: The ssh settings to use for this connection.
        """
        self._settings = settings
        # Builds the ssh command lines that are handed to the job runner.
        self._formatter = formatter.SshFormatter()
        # Serializes setup/teardown of the master connection in
        # setup_master_ssh.
        self._lock = threading.Lock()
        self._master_ssh_proc = None
        self._master_ssh_tempdir = None
        # Open port forwards created by create_ssh_tunnel.
        self._tunnels = list()

    def __del__(self):
        self.close()

    def setup_master_ssh(self, timeout_seconds=5):
        """Sets up the master ssh connection.

        Sets up the initial master ssh connection if it has not already been
        started. An existing master connection whose socket is missing or
        whose process has exited is torn down and replaced.

        Args:
            timeout_seconds: The time to wait for the master ssh connection to
                be made.

        Raises:
            Error: When setting up the master ssh connection fails.
        """
        with self._lock:
            if self._master_ssh_proc is not None:
                socket_path = self.socket_path
                if (not os.path.exists(socket_path) or
                        self._master_ssh_proc.poll() is not None):
                    logging.debug('Master ssh connection to %s is down.',
                                  self._settings.hostname)
                    self._cleanup_master_ssh()

            if self._master_ssh_proc is None:
                # Create a shared socket in a temp location.
                self._master_ssh_tempdir = tempfile.mkdtemp(
                    prefix='ssh-master')

                # Setup flags and options for running the master ssh
                # -N: Do not execute a remote command.
                # ControlMaster: Spawn a master connection.
                # ControlPath: The master connection socket path.
                extra_flags = {'-N': None}
                extra_options = {
                    'ControlMaster': True,
                    'ControlPath': self.socket_path,
                    'BatchMode': True
                }

                # Construct the command and start it.
                master_cmd = self._formatter.format_ssh_local_command(
                    self._settings,
                    extra_flags=extra_flags,
                    extra_options=extra_options)
                logging.info('Starting master ssh connection to %s',
                             self._settings.hostname)
                self._master_ssh_proc = job.run_async(master_cmd)

                # Wait for the master to create its control socket.
                end_time = time.time() + timeout_seconds
                while not os.path.exists(self.socket_path):
                    # If the master process already exited, the socket will
                    # never appear; fail now instead of waiting out the
                    # whole timeout.
                    if self._master_ssh_proc.poll() is not None:
                        self._cleanup_master_ssh()
                        raise Error('Master ssh connection exited before '
                                    'it was established.')
                    if time.time() >= end_time:
                        self._cleanup_master_ssh()
                        raise Error('Master ssh connection timed out.')
                    time.sleep(.2)

    def run(self,
            command,
            timeout=3600,
            ignore_status=False,
            env=None,
            io_encoding='utf-8'):
        """Runs a remote command over ssh.

        Will ssh to a remote host and run a command. This method will
        block until the remote command is finished.

        Args:
            command: The command to execute over ssh. It is interpolated
                with '%s' into a shell command line, so a string is expected.
            timeout: number seconds to wait for command to finish.
            ignore_status: bool True to ignore the exit code of the remote
                subprocess. Note that if you do ignore status codes,
                you should handle non-zero exit codes explicitly.
            env: dict environment variables to setup on the remote host.
            io_encoding: str unicode encoding of command output.
                NOTE(review): currently not used by this method; the
                encoding reported by job.run is used instead.

        Returns:
            A job.Result containing the results of the ssh command. Its
            stdout excludes the internal "CONNECTED" marker line.

        Raises:
            job.TimeoutError: When the remote command took too long to
                execute.
            Error: When the ssh connection failed to be created.
            job.Error: Ssh worked, but the command exited with a non-zero
                status and ignore_status is False.
        """
        if env is None:
            env = {}

        try:
            self.setup_master_ssh(self._settings.connect_timeout)
        except Error:
            logging.warning('Failed to create master ssh connection, using '
                            'normal ssh connection.')

        extra_options = {'BatchMode': True}
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path

        # A unique marker echoed before the command; seeing it in stdout
        # proves the ssh connection itself succeeded.
        identifier = str(uuid.uuid4())
        full_command = 'echo "CONNECTED: %s"; %s' % (identifier, command)

        terminal_command = self._formatter.format_command(
            full_command, env, self._settings, extra_options=extra_options)

        dns_retry_count = 2
        while True:
            result = job.run(terminal_command,
                             ignore_status=True,
                             timeout=timeout)
            output = result.stdout

            # Check for a connected message to prevent false negatives.
            valid_connection = re.search(
                '^CONNECTED: %s' % identifier, output, flags=re.MULTILINE)
            if valid_connection:
                # Remove the first line that contains the connect message.
                line_index = output.find('\n')
                real_output = output[line_index + 1:].encode(result._encoding)

                result = job.Result(
                    command=result.command,
                    stdout=real_output,
                    stderr=result._raw_stderr,
                    exit_status=result.exit_status,
                    duration=result.duration,
                    did_timeout=result.did_timeout,
                    encoding=result._encoding)
                # Honor ignore_status as documented.
                if result.exit_status and not ignore_status:
                    raise job.Error(result)
                return result

            # The marker never appeared: ssh itself failed. Classify why.
            error_string = result.stderr

            had_dns_failure = (result.exit_status == 255 and re.search(
                r'^ssh: .*: Name or service not known',
                error_string,
                flags=re.MULTILINE))
            if had_dns_failure:
                dns_retry_count -= 1
                if not dns_retry_count:
                    raise Error('DNS failed to find host.', result)
                logging.debug('Failed to connecto to host, retrying...')
            else:
                break

        had_timeout = re.search(
            r'^ssh: connect to host .* port .*: '
            r'Connection timed out\r$',
            error_string,
            flags=re.MULTILINE)
        if had_timeout:
            raise Error('Ssh timed out.', result)

        permission_denied = 'Permission denied' in error_string
        if permission_denied:
            raise Error('Permission denied.', result)

        unknown_host = re.search(
            r'ssh: Could not resolve hostname .*: '
            r'Name or service not known',
            error_string,
            flags=re.MULTILINE)
        if unknown_host:
            raise Error('Unknown host.', result)

        raise Error('The job failed for unkown reasons.', result)

    def run_async(self, command, env=None):
        """Starts up a background command over ssh.

        Will ssh to a remote host and startup a command. This method will
        block until there is confirmation that the remote command has started.

        Args:
            command: The command to execute over ssh, as a string.
            env: A dictionary of environment variables to setup on the remote
                host.

        Returns:
            The result of the command to launch the background job; its
            stdout holds the remote PID (from `echo -n $!`).

        Raises:
            Same as run(): job.TimeoutError, Error, or job.Error.
        """
        # Detach the command from the ssh session's stdio and background it
        # so run() returns as soon as it is launched.
        command = '(%s) < /dev/null > /dev/null 2>&1 & echo -n $!' % command
        result = self.run(command, env=env)
        return result

    def close(self):
        """Clean up open connections to remote host."""
        self._cleanup_master_ssh()
        while self._tunnels:
            self.close_ssh_tunnel(self._tunnels[0].local_port)

    def _cleanup_master_ssh(self):
        """
        Release all resources (process, temporary directory) used by an active
        master SSH connection.
        """
        # If a master SSH connection is running, kill it.
        if self._master_ssh_proc is not None:
            logging.debug('Nuking master_ssh_job.')
            self._master_ssh_proc.kill()
            self._master_ssh_proc.wait()
            self._master_ssh_proc = None

        # Remove the temporary directory for the master SSH socket.
        if self._master_ssh_tempdir is not None:
            logging.debug('Cleaning master_ssh_tempdir.')
            shutil.rmtree(self._master_ssh_tempdir)
            self._master_ssh_tempdir = None

    def create_ssh_tunnel(self, port, local_port=None):
        """Create an ssh tunnel from local_port to port.

        This securely forwards traffic from local_port on this machine to the
        remote SSH host at port.

        Args:
            port: remote port on the host.
            local_port: local forwarding port, or None to pick an available
                port. When given and a tunnel to `port` already exists, that
                existing tunnel's local port is returned and no new tunnel
                is created.

        Returns:
            int, the local port of the tunnel.
        """
        if local_port is None:
            local_port = host_utils.get_available_host_port()
        else:
            for tunnel in self._tunnels:
                if tunnel.remote_port == port:
                    return tunnel.local_port

        extra_flags = {
            '-n': None,  # Read from /dev/null for stdin
            '-N': None,  # Do not execute a remote command
            '-q': None,  # Suppress warnings and diagnostic commands
            '-L': '%d:localhost:%d' % (local_port, port),
        }
        extra_options = dict()
        if self._master_ssh_proc:
            extra_options['ControlPath'] = self.socket_path
        tunnel_cmd = self._formatter.format_ssh_local_command(
            self._settings,
            extra_flags=extra_flags,
            extra_options=extra_options)
        logging.debug('Full tunnel command: %s', tunnel_cmd)
        # Exec the ssh process directly so that when we deliver signals, we
        # deliver them straight to the child process.
        tunnel_proc = job.run_async(tunnel_cmd)
        logging.debug('Started ssh tunnel, local = %d'
                      ' remote = %d, pid = %d', local_port, port,
                      tunnel_proc.pid)
        self._tunnels.append(_Tunnel(local_port, port, tunnel_proc))
        return local_port

    def close_ssh_tunnel(self, local_port):
        """Close a previously created ssh tunnel of a TCP port.

        Args:
            local_port: int port on localhost previously forwarded to the remote
                host.

        Returns:
            integer port number this port was forwarded to on the remote host or
            None if no tunnel was found.
        """
        idx = None
        for i, tunnel in enumerate(self._tunnels):
            if tunnel.local_port == local_port:
                idx = i
                break
        if idx is not None:
            tunnel = self._tunnels.pop(idx)
            tunnel.proc.kill()
            tunnel.proc.wait()
            return tunnel.remote_port
        return None

    def send_file(self, local_path, remote_path):
        """Send a file from the local host to the remote host.

        Args:
            local_path: string path of file to send on local host.
            remote_path: string path to copy file to on remote host.
        """
        # TODO: This may belong somewhere else: b/32572515
        # NOTE(review): paths are interpolated unquoted into the scp command
        # line; paths containing spaces or shell metacharacters will break.
        user_host = self._formatter.format_host_name(self._settings)
        job.run('scp %s %s:%s' % (local_path, user_host, remote_path))

    def find_free_port(self, interface_name='localhost'):
        """Find an unused port on the remote host.

        Note that this method is inherently racy, since it is impossible
        to promise that the remote port will remain free.

        Args:
            interface_name: string name of interface to check whether a
                port is used against.

        Returns:
            integer port number on remote interface that was free.
        """
        # TODO: This may belong somewhere else: b/3257251
        free_port_cmd = (
            'python -c "import socket; s=socket.socket(); '
            's.bind((\'%s\', 0)); print(s.getsockname()[1]); s.close()"'
        ) % interface_name
        port = int(self.run(free_port_cmd).stdout)
        # Yield to the os to ensure the port gets cleaned up.
        time.sleep(0.001)
        return port