• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""A wrapper around ssh for common operations on a CrOS-based device"""
5import logging
6import os
7import re
8import shutil
9import subprocess
10import sys
11import tempfile
12
13# TODO(nduca): This whole file is built up around making individual ssh calls
14# for each operation. It really could get away with a single ssh session built
15# around pexpect, I suspect, if we wanted it to be faster. But, this was
16# convenient.
17
# Chrome is normally installed under /opt, but some developers run it from
# /usr/local/... instead (e.g. debug builds too large for the rootfs), so both
# locations are searched.
#
# Each entry deliberately ends with a space: GetChromeProcess() compares these
# with startswith() against a full command line, so the trailing space makes
# the executable path end exactly here, followed by its arguments.
_CHROME_PATHS = [
    '/opt/google/chrome/chrome ',            # default install location
    '/usr/local/opt/google/chrome/chrome ',  # developer workflow location
]
23
def IsRunningOnCrosDevice():
  """Reports whether this code is executing on a ChromeOS device.

  Only Linux hosts are considered. The host counts as ChromeOS when
  /etc/lsb-release exists and mentions CHROMEOS_RELEASE_NAME.
  """
  lsb_release = '/etc/lsb-release'
  if not sys.platform.startswith('linux'):
    return False
  if not os.path.exists(lsb_release):
    return False
  with open(lsb_release, 'r') as release_file:
    contents = release_file.read()
  return 'CHROMEOS_RELEASE_NAME' in contents
33
def RunCmd(args, cwd=None, quiet=False):
  """Opens a subprocess to execute a program and returns its return value.

  The child's stdin, stdout and stderr are all connected to os.devnull, so
  its output is discarded.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, the command is not logged.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: The program could not be started (e.g. it is not installed).
  """
  if not quiet:
    # A plain string already is the command; joining it would interleave its
    # characters with spaces.
    command = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', command, cwd or '')
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=devnull,
                         stderr=devnull, stdin=devnull, shell=False)
    return p.wait()
52
def GetAllCmdOutput(args, cwd=None, quiet=False):
  """Opens a subprocess to execute a program and returns its output.

  The child's stdin is connected to os.devnull.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, neither the command nor its output is logged.

  Returns:
    A (stdout, stderr) tuple with everything the command wrote to each
    stream. Unless |quiet| is set, both are also logged at debug level.

  Raises:
    OSError: The program could not be started.
  """
  if not quiet:
    # A plain string already is the command; joining it would interleave its
    # characters with spaces.
    command = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', command, cwd or '')
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE, stdin=devnull)
    stdout, stderr = p.communicate()
  if not quiet:
    logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
  return stdout, stderr
75
def HasSSH():
  """Returns True if both the ssh and scp programs can be launched.

  Each tool is started once with no arguments; only a failure to launch
  (OSError) matters, not the exit code.
  """
  try:
    for tool in ('ssh', 'scp'):
      RunCmd([tool], quiet=True)
  except OSError:
    logging.debug("HasSSH()->False")
    return False
  logging.debug("HasSSH()->True")
  return True
85
class LoginException(Exception):
  """Raised when logging into a device over ssh fails."""


class KeylessLoginRequiredException(LoginException):
  """Raised when the device rejects the login because ssh key auth is not set up."""
91
92class CrOSInterface(object):
93  # pylint: disable=R0923
94  def __init__(self, hostname = None, ssh_identity = None):
95    self._hostname = hostname
96    # List of ports generated from GetRemotePort() that may not be in use yet.
97    self._reserved_ports = []
98
99    if self.local:
100      return
101
102    self._ssh_identity = None
103    self._ssh_args = ['-o ConnectTimeout=5',
104                      '-o StrictHostKeyChecking=no',
105                      '-o KbdInteractiveAuthentication=no',
106                      '-o PreferredAuthentications=publickey',
107                      '-o UserKnownHostsFile=/dev/null']
108
109    if ssh_identity:
110      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))
111
112  @property
113  def local(self):
114    return not self._hostname
115
116  @property
117  def hostname(self):
118    return self._hostname
119
120  def FormSSHCommandLine(self, args, extra_ssh_args=None):
121    if self.local:
122      # We run the command through the shell locally for consistency with
123      # how commands are run through SSH (crbug.com/239161). This work
124      # around will be unnecessary once we implement a persistent SSH
125      # connection to run remote commands (crbug.com/239607).
126      return ['sh', '-c', " ".join(args)]
127
128    full_args = ['ssh',
129                 '-o ForwardX11=no',
130                 '-o ForwardX11Trusted=no',
131                 '-n'] + self._ssh_args
132    if self._ssh_identity is not None:
133      full_args.extend(['-i', self._ssh_identity])
134    if extra_ssh_args:
135      full_args.extend(extra_ssh_args)
136    full_args.append('root@%s' % self._hostname)
137    full_args.extend(args)
138    return full_args
139
140  def _RemoveSSHWarnings(self, toClean):
141    """Removes specific ssh warning lines from a string.
142
143    Args:
144      toClean: A string that may be containing multiple lines.
145
146    Returns:
147      A copy of toClean with all the Warning lines removed.
148    """
149    # Remove the Warning about connecting to a new host for the first time.
150    return re.sub('Warning: Permanently added [^\n]* to the list of known '
151                  'hosts.\s\n', '', toClean)
152
153  def RunCmdOnDevice(self, args, cwd=None, quiet=False):
154    stdout, stderr = GetAllCmdOutput(
155        self.FormSSHCommandLine(args), cwd, quiet=quiet)
156    # The initial login will add the host to the hosts file but will also print
157    # a warning to stderr that we need to remove.
158    stderr = self._RemoveSSHWarnings(stderr)
159    return stdout, stderr
160
161  def TryLogin(self):
162    logging.debug('TryLogin()')
163    assert not self.local
164    stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
165    if stderr != '':
166      if 'Host key verification failed' in stderr:
167        raise LoginException(('%s host key verification failed. ' +
168                             'SSH to it manually to fix connectivity.') %
169            self._hostname)
170      if 'Operation timed out' in stderr:
171        raise LoginException('Timed out while logging into %s' % self._hostname)
172      if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
173        raise LoginException('Permissions for %s are too open. To fix this,\n'
174                             'chmod 600 %s' % (self._ssh_identity,
175                                               self._ssh_identity))
176      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
177        raise KeylessLoginRequiredException(
178          'Need to set up ssh auth for %s' % self._hostname)
179      raise LoginException('While logging into %s, got %s' % (
180          self._hostname, stderr))
181    if stdout != 'root\n':
182      raise LoginException(
183        'Logged into %s, expected $USER=root, but got %s.' % (
184          self._hostname, stdout))
185
186  def FileExistsOnDevice(self, file_name):
187    if self.local:
188      return os.path.exists(file_name)
189
190    stdout, stderr = self.RunCmdOnDevice([
191        'if', 'test', '-e', file_name, ';',
192        'then', 'echo', '1', ';',
193        'fi'
194        ], quiet=True)
195    if stderr != '':
196      if "Connection timed out" in stderr:
197        raise OSError('Machine wasn\'t responding to ssh: %s' %
198                      stderr)
199      raise OSError('Unexpected error: %s' % stderr)
200    exists = stdout == '1\n'
201    logging.debug("FileExistsOnDevice(<text>, %s)->%s" % (file_name, exists))
202    return exists
203
204  def PushFile(self, filename, remote_filename):
205    if self.local:
206      args = ['cp', '-r', filename, remote_filename]
207      stdout, stderr = GetAllCmdOutput(args, quiet=True)
208      if stderr != '':
209        raise OSError('No such file or directory %s' % stderr)
210      return
211
212    args = ['scp', '-r' ] + self._ssh_args
213    if self._ssh_identity:
214      args.extend(['-i', self._ssh_identity])
215
216    args.extend([os.path.abspath(filename),
217                 'root@%s:%s' % (self._hostname, remote_filename)])
218
219    stdout, stderr = GetAllCmdOutput(args, quiet=True)
220    stderr = self._RemoveSSHWarnings(stderr)
221    if stderr != '':
222      raise OSError('No such file or directory %s' % stderr)
223
224  def PushContents(self, text, remote_filename):
225    logging.debug("PushContents(<text>, %s)" % remote_filename)
226    with tempfile.NamedTemporaryFile() as f:
227      f.write(text)
228      f.flush()
229      self.PushFile(f.name, remote_filename)
230
231  def GetFile(self, filename, destfile=None):
232    """Copies a local file |filename| to |destfile| on the device.
233
234    Args:
235      filename: The name of the local source file.
236      destfile: The name of the file to copy to, and if it is not specified
237        then it is the basename of the source file.
238
239    """
240    logging.debug("GetFile(%s, %s)" % (filename, destfile))
241    if self.local:
242      if destfile is not None and destfile != filename:
243        shutil.copyfile(filename, destfile)
244      return
245
246    if destfile is None:
247      destfile = os.path.basename(filename)
248    args = ['scp'] + self._ssh_args
249    if self._ssh_identity:
250      args.extend(['-i', self._ssh_identity])
251
252    args.extend(['root@%s:%s' % (self._hostname, filename),
253                 os.path.abspath(destfile)])
254    stdout, stderr = GetAllCmdOutput(args, quiet=True)
255    stderr = self._RemoveSSHWarnings(stderr)
256    if stderr != '':
257      raise OSError('No such file or directory %s' % stderr)
258
259  def GetFileContents(self, filename):
260    """Get the contents of a file on the device.
261
262    Args:
263      filename: The name of the file on the device.
264
265    Returns:
266      A string containing the contents of the file.
267    """
268    # TODO: handle the self.local case
269    assert not self.local
270    t = tempfile.NamedTemporaryFile()
271    self.GetFile(filename, t.name)
272    with open(t.name, 'r') as f2:
273      res = f2.read()
274      logging.debug("GetFileContents(%s)->%s" % (filename, res))
275      f2.close()
276      return res
277
278  def ListProcesses(self):
279    """Returns (pid, cmd, ppid, state) of all processes on the device."""
280    stdout, stderr = self.RunCmdOnDevice([
281        '/bin/ps', '--no-headers',
282        '-A',
283        '-o', 'pid,ppid,args:4096,state'], quiet=True)
284    assert stderr == '', stderr
285    procs = []
286    for l in stdout.split('\n'): # pylint: disable=E1103
287      if l == '':
288        continue
289      m = re.match('^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', l, re.DOTALL)
290      assert m
291      procs.append((int(m.group(1)), m.group(3).rstrip(),
292                    int(m.group(2)), m.group(4)))
293    logging.debug("ListProcesses(<predicate>)->[%i processes]" % len(procs))
294    return procs
295
296  def _GetSessionManagerPid(self, procs):
297    """Returns the pid of the session_manager process, given the list of
298    processes."""
299    for pid, process, _, _ in procs:
300      argv = process.split()
301      if argv and os.path.basename(argv[0]) == 'session_manager':
302        return pid
303    return None
304
305  def GetChromeProcess(self):
306    """Locates the the main chrome browser process.
307
308    Chrome on cros is usually in /opt/google/chrome, but could be in
309    /usr/local/ for developer workflows - debug chrome is too large to fit on
310    rootfs.
311
312    Chrome spawns multiple processes for renderers. pids wrap around after they
313    are exhausted so looking for the smallest pid is not always correct. We
314    locate the session_manager's pid, and look for the chrome process that's an
315    immediate child. This is the main browser process.
316    """
317    procs = self.ListProcesses()
318    session_manager_pid = self._GetSessionManagerPid(procs)
319    if not session_manager_pid:
320      return None
321
322    # Find the chrome process that is the child of the session_manager.
323    for pid, process, ppid, _ in procs:
324      if ppid != session_manager_pid:
325        continue
326      for path in _CHROME_PATHS:
327        if process.startswith(path):
328          return {'pid': pid, 'path': path, 'args': process}
329    return None
330
331  def GetChromePid(self):
332    """Returns pid of main chrome browser process."""
333    result = self.GetChromeProcess()
334    if result and 'pid' in result:
335      return result['pid']
336    return None
337
338  def RmRF(self, filename):
339    logging.debug("rm -rf %s" % filename)
340    self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)
341
342  def Chown(self, filename):
343    self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])
344
345  def KillAllMatching(self, predicate):
346    kills = ['kill', '-KILL']
347    for pid, cmd, _, _ in self.ListProcesses():
348      if predicate(cmd):
349        logging.info('Killing %s, pid %d' % cmd, pid)
350        kills.append(pid)
351    logging.debug("KillAllMatching(<predicate>)->%i" % (len(kills) - 2))
352    if len(kills) > 2:
353      self.RunCmdOnDevice(kills, quiet=True)
354    return len(kills) - 2
355
356  def IsServiceRunning(self, service_name):
357    stdout, stderr = self.RunCmdOnDevice([
358        'status', service_name], quiet=True)
359    assert stderr == '', stderr
360    running = 'running, process' in stdout
361    logging.debug("IsServiceRunning(%s)->%s" % (service_name, running))
362    return running
363
364  def GetRemotePort(self):
365    netstat = self.RunCmdOnDevice(['netstat', '-ant'])
366    netstat = netstat[0].split('\n')
367    ports_in_use = []
368
369    for line in netstat[2:]:
370      if not line:
371        continue
372      address_in_use = line.split()[3]
373      port_in_use = address_in_use.split(':')[-1]
374      ports_in_use.append(int(port_in_use))
375
376    ports_in_use.extend(self._reserved_ports)
377
378    new_port = sorted(ports_in_use)[-1] + 1
379    self._reserved_ports.append(new_port)
380
381    return new_port
382
383  def IsHTTPServerRunningOnPort(self, port):
384    wget_output = self.RunCmdOnDevice(
385        ['wget', 'localhost:%i' % (port), '-T1', '-t1'])
386
387    if 'Connection refused' in wget_output[1]:
388      return False
389
390    return True
391
392  def FilesystemMountedAt(self, path):
393    """Returns the filesystem mounted at |path|"""
394    df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
395    df_ary = df_out.split('\n')
396    # 3 lines for title, mount info, and empty line.
397    if len(df_ary) == 3:
398      line_ary = df_ary[1].split()
399      if line_ary:
400        return line_ary[0]
401    return None
402
403  def CryptohomePath(self, user):
404    """Returns the cryptohome mount point for |user|."""
405    stdout, stderr = self.RunCmdOnDevice(
406        ['cryptohome-path', 'user', "'%s'" % user])
407    if stderr != '':
408      raise OSError('cryptohome-path failed: %s' % stderr)
409    return stdout.rstrip()
410
411  def IsCryptohomeMounted(self, username, is_guest):
412    """Returns True iff |user|'s cryptohome is mounted."""
413    profile_path = self.CryptohomePath(username)
414    mount = self.FilesystemMountedAt(profile_path)
415    mount_prefix = 'guestfs' if is_guest else '/home/.shadow/'
416    return mount and mount.startswith(mount_prefix)
417
418  def TakeScreenShot(self, screenshot_prefix):
419    """Takes a screenshot, useful for debugging failures."""
420    # TODO(achuith): Find a better location for screenshots. Cros autotests
421    # upload everything in /var/log so use /var/log/screenshots for now.
422    SCREENSHOT_DIR = '/var/log/screenshots/'
423    SCREENSHOT_EXT = '.png'
424
425    self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
426    for i in xrange(25):
427      screenshot_file = ('%s%s-%d%s' %
428                         (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
429      if not self.FileExistsOnDevice(screenshot_file):
430        self.RunCmdOnDevice([
431            'DISPLAY=:0.0 XAUTHORITY=/home/chronos/.Xauthority '
432            '/usr/local/bin/import',
433            '-window root',
434            '-depth 8',
435            screenshot_file])
436        return
437    logging.warning('screenshot directory full.')
438
439  def RestartUI(self, clear_enterprise_policy):
440    logging.info('(Re)starting the ui (logs the user out)')
441    if clear_enterprise_policy:
442      self.RunCmdOnDevice(['stop', 'ui'])
443      self.RmRF('/var/lib/whitelist/*')
444      self.RmRF('/home/chronos/Local\ State')
445
446    if self.IsServiceRunning('ui'):
447      self.RunCmdOnDevice(['restart', 'ui'])
448    else:
449      self.RunCmdOnDevice(['start', 'ui'])
450