# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
"""A wrapper around ssh for common operations on a CrOS-based device"""
import logging
import re
import os
import shutil
import subprocess
import tempfile

# Some developers' workflow includes running the Chrome process from
# /usr/local/... instead of the default location. We have to check for both
# paths in order to support this workflow.
# Note the trailing space: GetChromeProcess() matches these with startswith(),
# so a process only matches when the binary path is followed by arguments
# (e.g. '/opt/google/chrome/chrome-sandbox' does not match).
_CHROME_PATHS = ['/opt/google/chrome/chrome ',
                 '/usr/local/opt/google/chrome/chrome ']

# Matches the warning ssh writes to stderr the first time it connects to a
# host that is not yet in the known-hosts file. Both "\r\n" and "\n" line
# endings are accepted.
_SSH_KNOWN_HOST_WARNING_RE = re.compile(
    r'Warning: Permanently added [^\n]* to the list of known hosts\.\s*?\n')

# Parses one line of `ps -o pid,ppid,args:4096,state` output into
# (pid, ppid, args, state).
_PS_LINE_RE = re.compile(r'^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', re.DOTALL)


def _CommandForLog(args, cwd):
  """Returns a one-line description of a command for debug logging.

  Args:
    args: A string or a sequence of program arguments, as accepted by
      subprocess.Popen.
    cwd: The working directory the command runs in, or None.
  """
  if isinstance(args, str):
    command = args
  else:
    command = ' '.join(args)
  return command + ' ' + (cwd or '')


def RunCmd(args, cwd=None, quiet=False):
  """Runs a program in a subprocess and returns its exit code.

  The subprocess's stdin, stdout and stderr are all connected to os.devnull.

  Args:
    args: A string or a sequence of program arguments. The program to execute
      is the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, the command line is not logged.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: If the program could not be executed (e.g. it does not exist).
  """
  if not quiet:
    logging.debug(_CommandForLog(args, cwd))
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=devnull,
                         stderr=devnull, stdin=devnull, shell=False)
    return p.wait()


def GetAllCmdOutput(args, cwd=None, quiet=False):
  """Runs a program in a subprocess and returns everything it printed.

  The subprocess's stdin is connected to os.devnull.

  Args:
    args: A string or a sequence of program arguments. The program to execute
      is the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, neither the command line nor its output is logged.

  Returns:
    A (stdout, stderr) tuple with the command's complete captured output.
  """
  if not quiet:
    logging.debug(_CommandForLog(args, cwd))
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE, stdin=devnull)
    stdout, stderr = p.communicate()
  if not quiet:
    logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
  return stdout, stderr


def HasSSH():
  """Returns True if the local `ssh` and `scp` programs can be executed."""
  try:
    RunCmd(['ssh'], quiet=True)
    RunCmd(['scp'], quiet=True)
    logging.debug("HasSSH()->True")
    return True
  except OSError:
    logging.debug("HasSSH()->False")
    return False


class LoginException(Exception):
  """Raised when logging into the device over ssh fails."""


class KeylessLoginRequiredException(LoginException):
  """Raised by TryLogin() when ssh reports publickey auth was denied."""


class CrOSInterface(object):
  """Runs commands on, and copies files to and from, a Chrome OS device.

  Without a hostname the interface is "local": commands run on this machine
  through `sh -c` and files are copied with cp/shutil. With a hostname,
  commands run as root on that host over ssh, through a control socket set up
  in the constructor, and files are copied with scp.

  Instances are context managers; leaving the `with` block calls
  CloseConnection().
  """
  # pylint: disable=R0923
  def __init__(self, hostname=None, ssh_identity=None):
    """Creates the interface and, for a remote host, starts the ssh master.

    Args:
      hostname: The device to talk to, or None to operate locally.
      ssh_identity: Optional path to a private key passed to ssh/scp via -i;
        '~' is expanded and the path is made absolute.
    """
    self._hostname = hostname
    # List of ports generated from GetRemotePort() that may not be in use yet.
    self._reserved_ports = []
    self._ssh_identity = None

    if self.local:
      return

    # Each option is a single argv entry; ssh accepts the space after -o.
    self._ssh_args = ['-o ConnectTimeout=5',
                      '-o StrictHostKeyChecking=no',
                      '-o KbdInteractiveAuthentication=no',
                      '-o PreferredAuthentications=publickey',
                      '-o UserKnownHostsFile=/dev/null',
                      '-o ControlMaster=no']

    if ssh_identity:
      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))

    # Establish master SSH connection using ControlPersist.
    # Since only one test will be run on a remote host at a time,
    # the control socket filename can be telemetry@hostname.
    self._ssh_control_file = '/tmp/' + 'telemetry' + '@' + hostname
    with open(os.devnull, 'w') as devnull:
      subprocess.call(self.FormSSHCommandLine(['-M', '-o ControlPersist=yes']),
                      stdin=devnull, stdout=devnull, stderr=devnull)

  def __enter__(self):
    return self

  def __exit__(self, *args):
    self.CloseConnection()

  @property
  def local(self):
    """True when no hostname was given and commands run on this machine."""
    return not self._hostname

  @property
  def hostname(self):
    """The remote hostname, or None for a local interface."""
    return self._hostname

  def FormSSHCommandLine(self, args, extra_ssh_args=None):
    """Builds the argv that runs |args| on the device.

    Args:
      args: The command and its arguments. They are joined with spaces and
        interpreted by a shell, both locally and remotely.
      extra_ssh_args: Optional additional ssh options (remote mode only).

    Returns:
      For a local interface, ['sh', '-c', <joined args>]. Otherwise an ssh
      argv using the control socket, the configured options and identity,
      targeting root@<hostname>, followed by |args|.
    """
    if self.local:
      # We run the command through the shell locally for consistency with
      # how commands are run through SSH (crbug.com/239161). This work
      # around will be unnecessary once we implement a persistent SSH
      # connection to run remote commands (crbug.com/239607).
      return ['sh', '-c', " ".join(args)]

    full_args = ['ssh',
                 '-o ForwardX11=no',
                 '-o ForwardX11Trusted=no',
                 '-n', '-S', self._ssh_control_file] + self._ssh_args
    if self._ssh_identity is not None:
      full_args.extend(['-i', self._ssh_identity])
    if extra_ssh_args:
      full_args.extend(extra_ssh_args)
    full_args.append('root@%s' % self._hostname)
    full_args.extend(args)
    return full_args

  def _RemoveSSHWarnings(self, toClean):
    """Removes specific ssh warning lines from a string.

    Args:
      toClean: A string that may be containing multiple lines.

    Returns:
      A copy of toClean with all the "Permanently added ... to the list of
      known hosts." warning lines removed.
    """
    # Remove the Warning about connecting to a new host for the first time.
    return _SSH_KNOWN_HOST_WARNING_RE.sub('', toClean)

  def RunCmdOnDevice(self, args, cwd=None, quiet=False):
    """Runs |args| on the device and returns its (stdout, stderr).

    |cwd| is the local working directory of the spawned process (sh or ssh).
    The known-hosts warning is stripped from stderr.
    """
    stdout, stderr = GetAllCmdOutput(
        self.FormSSHCommandLine(args), cwd, quiet=quiet)
    # The initial login will add the host to the hosts file but will also print
    # a warning to stderr that we need to remove.
    stderr = self._RemoveSSHWarnings(stderr)
    return stdout, stderr

  def TryLogin(self):
    """Checks that commands can be run on the device as root over ssh.

    Raises:
      KeylessLoginRequiredException: If ssh public-key authentication was
        denied.
      LoginException: If ssh wrote any other error to stderr, or if the
        remote $USER is not root.
    """
    logging.debug('TryLogin()')
    assert not self.local
    stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
    if stderr != '':
      if 'Host key verification failed' in stderr:
        raise LoginException(('%s host key verification failed. ' +
                              'SSH to it manually to fix connectivity.') %
                             self._hostname)
      if 'Operation timed out' in stderr:
        raise LoginException('Timed out while logging into %s' % self._hostname)
      if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
        raise LoginException('Permissions for %s are too open. To fix this,\n'
                             'chmod 600 %s' % (self._ssh_identity,
                                               self._ssh_identity))
      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
        raise KeylessLoginRequiredException(
            'Need to set up ssh auth for %s' % self._hostname)
      raise LoginException('While logging into %s, got %s' % (
          self._hostname, stderr))
    if stdout != 'root\n':
      raise LoginException(
          'Logged into %s, expected $USER=root, but got %s.' % (
              self._hostname, stdout))

  def FileExistsOnDevice(self, file_name):
    """Returns True if |file_name| exists on the device.

    Raises:
      OSError: If the remote check wrote anything to stderr.
    """
    if self.local:
      return os.path.exists(file_name)

    stdout, stderr = self.RunCmdOnDevice([
        'if', 'test', '-e', file_name, ';',
        'then', 'echo', '1', ';',
        'fi'
    ], quiet=True)
    if stderr != '':
      if "Connection timed out" in stderr:
        raise OSError('Machine wasn\'t responding to ssh: %s' %
                      stderr)
      raise OSError('Unexpected error: %s' % stderr)
    exists = stdout == '1\n'
    logging.debug('FileExistsOnDevice(%s)->%s', file_name, exists)
    return exists

  def PushFile(self, filename, remote_filename):
    """Copies the local |filename| to |remote_filename| on the device.

    Directories are copied recursively (cp -r locally, scp -r remotely).

    Raises:
      OSError: If the copy command wrote anything to stderr.
    """
    if self.local:
      args = ['cp', '-r', filename, remote_filename]
      _, stderr = GetAllCmdOutput(args, quiet=True)
      if stderr != '':
        raise OSError('No such file or directory %s' % stderr)
      return

    args = ['scp', '-r'] + self._ssh_args
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])

    args.extend([os.path.abspath(filename),
                 'root@%s:%s' % (self._hostname, remote_filename)])

    _, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def PushContents(self, text, remote_filename):
    """Writes |text| to |remote_filename| on the device via a temp file."""
    logging.debug('PushContents(<text>, %s)', remote_filename)
    with tempfile.NamedTemporaryFile() as f:
      f.write(text)
      f.flush()
      self.PushFile(f.name, remote_filename)

  def GetFile(self, filename, destfile=None):
    """Copies |filename| from the device to the local file |destfile|.

    Args:
      filename: The path of the source file on the device.
      destfile: The local path to copy to. For a remote device it defaults to
        the basename of |filename| (relative to the current directory). For a
        local interface the file is only copied (with shutil.copyfile) when
        |destfile| is given and differs from |filename|.

    Raises:
      OSError: If scp wrote anything to stderr.
    """
    logging.debug('GetFile(%s, %s)', filename, destfile)
    if self.local:
      if destfile is not None and destfile != filename:
        shutil.copyfile(filename, destfile)
      return

    if destfile is None:
      destfile = os.path.basename(filename)
    args = ['scp'] + self._ssh_args
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])

    args.extend(['root@%s:%s' % (self._hostname, filename),
                 os.path.abspath(destfile)])
    _, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def GetFileContents(self, filename):
    """Get the contents of a file on the device.

    Args:
      filename: The name of the file on the device.

    Returns:
      A string containing the contents of the file.
    """
    # TODO: handle the self.local case
    assert not self.local
    # The temporary file is deleted when the outer `with` closes it.
    with tempfile.NamedTemporaryFile() as t:
      self.GetFile(filename, t.name)
      with open(t.name, 'r') as f2:
        res = f2.read()
    logging.debug('GetFileContents(%s)->%s', filename, res)
    return res

  def ListProcesses(self):
    """Returns (pid, cmd, ppid, state) of all processes on the device."""
    stdout, stderr = self.RunCmdOnDevice([
        '/bin/ps', '--no-headers',
        '-A',
        '-o', 'pid,ppid,args:4096,state'], quiet=True)
    assert stderr == '', stderr
    procs = []
    for line in stdout.split('\n'):  # pylint: disable=E1103
      if line == '':
        continue
      m = _PS_LINE_RE.match(line)
      assert m
      procs.append((int(m.group(1)), m.group(3).rstrip(),
                    int(m.group(2)), m.group(4)))
    logging.debug('ListProcesses()->[%i processes]', len(procs))
    return procs

  def _GetSessionManagerPid(self, procs):
    """Returns the pid of the session_manager process, given the list of
    processes (as returned by ListProcesses()), or None if it is absent."""
    for pid, process, _, _ in procs:
      argv = process.split()
      if argv and os.path.basename(argv[0]) == 'session_manager':
        return pid
    return None

  def GetChromeProcess(self):
    """Locates the main chrome browser process.

    Chrome on cros is usually in /opt/google/chrome, but could be in
    /usr/local/ for developer workflows - debug chrome is too large to fit on
    rootfs.

    Chrome spawns multiple processes for renderers. pids wrap around after they
    are exhausted so looking for the smallest pid is not always correct. We
    locate the session_manager's pid, and look for the chrome process that's an
    immediate child. This is the main browser process.

    Returns:
      A dict with keys 'pid', 'path' (the matching _CHROME_PATHS entry) and
      'args' (the full command line), or None if no such process exists.
    """
    procs = self.ListProcesses()
    session_manager_pid = self._GetSessionManagerPid(procs)
    if not session_manager_pid:
      return None

    # Find the chrome process that is the child of the session_manager.
    for pid, process, ppid, _ in procs:
      if ppid != session_manager_pid:
        continue
      for path in _CHROME_PATHS:
        if process.startswith(path):
          return {'pid': pid, 'path': path, 'args': process}
    return None

  def GetChromePid(self):
    """Returns pid of main chrome browser process, or None if not found."""
    result = self.GetChromeProcess()
    if result and 'pid' in result:
      return result['pid']
    return None

  def RmRF(self, filename):
    """Runs `rm -rf |filename|` on the device.

    The command goes through a shell, so |filename| may contain globs.
    """
    logging.debug('rm -rf %s', filename)
    self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)

  def Chown(self, filename):
    """Recursively changes the owner of |filename| to chronos:chronos."""
    self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])

  def KillAllMatching(self, predicate):
    """Sends SIGKILL to every device process whose command line matches.

    Args:
      predicate: A callable that takes a process's command line (a string)
        and returns True if that process should be killed.

    Returns:
      The number of matching processes (i.e. pids passed to kill).
    """
    kills = ['kill', '-KILL']
    for pid, cmd, _, _ in self.ListProcesses():
      if predicate(cmd):
        logging.info('Killing %s, pid %d', cmd, pid)
        # The argv is joined into a shell command line, so every entry must
        # be a string.
        kills.append(str(pid))
    num_matched = len(kills) - 2
    logging.debug('KillAllMatching(<predicate>)->%i', num_matched)
    if num_matched > 0:
      self.RunCmdOnDevice(kills, quiet=True)
    return num_matched

  def IsServiceRunning(self, service_name):
    """Returns True if `status |service_name|` reports 'running, process'."""
    stdout, stderr = self.RunCmdOnDevice([
        'status', service_name], quiet=True)
    assert stderr == '', stderr
    running = 'running, process' in stdout
    logging.debug('IsServiceRunning(%s)->%s', service_name, running)
    return running

  def GetRemotePort(self):
    """Returns a port one higher than any in use or previously reserved.

    Ports are read from `netstat -ant` on the device. The returned port is
    recorded so later calls will not hand it out again even if nothing is
    listening on it yet.
    """
    netstat = self.RunCmdOnDevice(['netstat', '-ant'])
    netstat = netstat[0].split('\n')
    ports_in_use = []

    # The first two lines of netstat output are headers.
    for line in netstat[2:]:
      if not line:
        continue
      address_in_use = line.split()[3]
      port_in_use = address_in_use.split(':')[-1]
      ports_in_use.append(int(port_in_use))

    ports_in_use.extend(self._reserved_ports)

    new_port = sorted(ports_in_use)[-1] + 1
    self._reserved_ports.append(new_port)

    return new_port

  def IsHTTPServerRunningOnPort(self, port):
    """Returns False only if wget on the device reports 'Connection refused'
    for localhost:|port|."""
    wget_output = self.RunCmdOnDevice(
        ['wget', 'localhost:%i' % (port), '-T1', '-t1'])

    if 'Connection refused' in wget_output[1]:
      return False

    return True

  def FilesystemMountedAt(self, path):
    """Returns the filesystem mounted at |path|, or None if df output is
    not in the expected shape."""
    df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
    df_ary = df_out.split('\n')
    # 3 lines for title, mount info, and empty line.
    if len(df_ary) == 3:
      line_ary = df_ary[1].split()
      if line_ary:
        return line_ary[0]
    return None

  def CryptohomePath(self, user):
    """Returns the cryptohome mount point for |user|.

    Raises:
      OSError: If cryptohome-path wrote anything to stderr.
    """
    stdout, stderr = self.RunCmdOnDevice(
        ['cryptohome-path', 'user', "'%s'" % user])
    if stderr != '':
      raise OSError('cryptohome-path failed: %s' % stderr)
    return stdout.rstrip()

  def IsCryptohomeMounted(self, username, is_guest):
    """Returns True iff |username|'s cryptohome is mounted.

    Args:
      username: The user whose cryptohome to check.
      is_guest: If True, expect a 'guestfs' mount; otherwise expect a
        filesystem under /home/.shadow/.
    """
    profile_path = self.CryptohomePath(username)
    mount = self.FilesystemMountedAt(profile_path)
    mount_prefix = 'guestfs' if is_guest else '/home/.shadow/'
    return bool(mount and mount.startswith(mount_prefix))

  def TakeScreenShot(self, screenshot_prefix):
    """Takes a screenshot, useful for debugging failures.

    Writes the first unused /var/log/screenshots/<prefix>-<i>.png for i in
    0..24, or logs a warning if all 25 already exist.
    """
    # TODO(achuith): Find a better location for screenshots. Cros autotests
    # upload everything in /var/log so use /var/log/screenshots for now.
    SCREENSHOT_DIR = '/var/log/screenshots/'
    SCREENSHOT_EXT = '.png'

    self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
    for i in range(25):
      screenshot_file = ('%s%s-%d%s' %
                         (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
      if not self.FileExistsOnDevice(screenshot_file):
        self.RunCmdOnDevice([
            'DISPLAY=:0.0 XAUTHORITY=/home/chronos/.Xauthority '
            '/usr/local/bin/import',
            '-window root',
            '-depth 8',
            screenshot_file])
        return
    logging.warning('screenshot directory full.')

  def RestartUI(self, clear_enterprise_policy):
    """(Re)starts the 'ui' service, which logs the user out.

    Args:
      clear_enterprise_policy: If True, stop the ui first and delete
        /var/lib/whitelist/* and /home/chronos/Local State.
    """
    logging.info('(Re)starting the ui (logs the user out)')
    if clear_enterprise_policy:
      self.RunCmdOnDevice(['stop', 'ui'])
      self.RmRF('/var/lib/whitelist/*')
      # The backslash escapes the space for the shell the command runs in.
      self.RmRF(r'/home/chronos/Local\ State')

    if self.IsServiceRunning('ui'):
      self.RunCmdOnDevice(['restart', 'ui'])
    else:
      self.RunCmdOnDevice(['start', 'ui'])

  def CloseConnection(self):
    """Asks the ssh master connection to exit (no-op for local)."""
    if not self.local:
      with open(os.devnull, 'w') as devnull:
        subprocess.call(self.FormSSHCommandLine(['-O', 'exit', self._hostname]),
                        stdout=devnull, stderr=devnull)