• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""A wrapper around ssh for common operations on a CrOS-based device"""
5import logging
6import re
7import os
8import shutil
9import subprocess
10import tempfile
11
12# Some developers' workflow includes running the Chrome process from
13# /usr/local/... instead of the default location. We have to check for both
14# paths in order to support this workflow.
15_CHROME_PATHS = ['/opt/google/chrome/chrome ',
16                '/usr/local/opt/google/chrome/chrome ']
17
def RunCmd(args, cwd=None, quiet=False):
  """Runs a program in a subprocess and returns its exit status.

  The child's stdin, stdout and stderr are all attached to os.devnull, so
  anything the program prints is discarded.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, the command line is not written to the debug log.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: if the program could not be started (for example because it does
      not exist).
  """
  if not quiet:
    # A bare string must not be passed through join(), which would insert a
    # space between every character of the program name.
    command_line = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', command_line, cwd or '')
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=devnull,
                         stderr=devnull, stdin=devnull, shell=False)
    return p.wait()
36
def GetAllCmdOutput(args, cwd=None, quiet=False):
  """Runs a program in a subprocess and returns everything it printed.

  The child's stdin is attached to os.devnull; its stdout and stderr are both
  captured through pipes.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, neither the command line nor the captured output is
      written to the debug log.

  Returns:
    A (stdout, stderr) tuple with the data the command wrote to each stream.

  Raises:
    OSError: if the program could not be started (for example because it does
      not exist).
  """
  if not quiet:
    # A bare string must not be passed through join(), which would insert a
    # space between every character of the program name.
    command_line = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', command_line, cwd or '')
  with open(os.devnull, 'w') as devnull:
    p = subprocess.Popen(args=args, cwd=cwd, stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE, stdin=devnull)
    # communicate() drains both pipes and waits for the process to exit.
    stdout, stderr = p.communicate()
    if not quiet:
      logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
    return stdout, stderr
59
def HasSSH():
  """Reports whether both the ssh and scp programs can be launched.

  Each program is started once with no arguments; its exit status is ignored
  and only a failure to start it (OSError) counts as "not available".

  Returns:
    True if both programs could be started, False otherwise.
  """
  try:
    for program in ('ssh', 'scp'):
      RunCmd([program], quiet=True)
  except OSError:
    logging.debug("HasSSH()->False")
    return False
  else:
    logging.debug("HasSSH()->True")
    return True
69
class LoginException(Exception):
  """Raised when logging into the device over ssh does not succeed."""
72
class KeylessLoginRequiredException(LoginException):
  """Raised when login is refused because ssh key auth is not set up."""
75
class CrOSInterface(object):
  """Runs commands and moves files on a CrOS device, locally or over ssh.

  Without a hostname the interface is "local": commands are run through
  `sh -c` on this machine and file operations act on the local filesystem.
  With a hostname, commands are sent as root over ssh through a control-master
  connection that __init__ starts and CloseConnection() shuts down. The object
  can be used as a context manager to close that connection automatically.
  """
  # pylint: disable=R0923
  def __init__(self, hostname=None, ssh_identity=None):
    """Creates the interface and, for a remote host, the ssh master connection.

    Args:
      hostname: Host to connect to. A false value (None, '') selects local
        mode.
      ssh_identity: Optional path to a private key file passed to ssh/scp
        with -i. '~' is expanded and the path is made absolute.
    """
    self._hostname = hostname
    # List of ports generated from GetRemotePort() that may not be in use yet.
    self._reserved_ports = []

    # ssh-related state. A local interface keeps these defaults so the
    # attributes always exist, even though local code paths never read them.
    self._ssh_identity = None
    self._ssh_args = []
    self._ssh_control_file = None

    if self.local:
      return

    self._ssh_args = ['-o ConnectTimeout=5',
                      '-o StrictHostKeyChecking=no',
                      '-o KbdInteractiveAuthentication=no',
                      '-o PreferredAuthentications=publickey',
                      '-o UserKnownHostsFile=/dev/null',
                      '-o ControlMaster=no']

    if ssh_identity:
      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))

    # Establish master SSH connection using ControlPersist.
    # Since only one test will be run on a remote host at a time,
    # the control socket filename can be telemetry@hostname.
    self._ssh_control_file = '/tmp/' + 'telemetry' + '@' + hostname
    with open(os.devnull, 'w') as devnull:
      subprocess.call(self.FormSSHCommandLine(['-M', '-o ControlPersist=yes']),
                      stdin=devnull, stdout=devnull, stderr=devnull)

  def __enter__(self):
    return self

  def __exit__(self, *args):
    self.CloseConnection()

  @property
  def local(self):
    """True when no hostname was given, i.e. commands run on this machine."""
    return not self._hostname

  @property
  def hostname(self):
    """The hostname passed to the constructor (None or '' in local mode)."""
    return self._hostname

  def FormSSHCommandLine(self, args, extra_ssh_args=None):
    """Builds the argv list that runs |args| on the device.

    Args:
      args: Sequence of strings forming the command to run.
      extra_ssh_args: Optional extra ssh options, inserted before the
        destination. Ignored in local mode.

    Returns:
      A list suitable for subprocess: `sh -c "<joined args>"` in local mode,
      otherwise an ssh command line targeting root@<hostname>.
    """
    if self.local:
      # We run the command through the shell locally for consistency with
      # how commands are run through SSH (crbug.com/239161). This work
      # around will be unnecessary once we implement a persistent SSH
      # connection to run remote commands (crbug.com/239607).
      return ['sh', '-c', " ".join(args)]

    full_args = ['ssh',
                 '-o ForwardX11=no',
                 '-o ForwardX11Trusted=no',
                 '-n', '-S', self._ssh_control_file] + self._ssh_args
    if self._ssh_identity is not None:
      full_args.extend(['-i', self._ssh_identity])
    if extra_ssh_args:
      full_args.extend(extra_ssh_args)
    full_args.append('root@%s' % self._hostname)
    full_args.extend(args)
    return full_args

  def _RemoveSSHWarnings(self, toClean):
    """Removes specific ssh warning lines from a string.

    Args:
      toClean: A string that may be containing multiple lines.

    Returns:
      A copy of toClean with all the Warning lines removed.
    """
    # Remove the Warning about connecting to a new host for the first time.
    # Raw strings keep the regex escapes intact; the pattern is unchanged.
    return re.sub(r'Warning: Permanently added [^\n]* to the list of known '
                  r'hosts.\s\n', '', toClean)

  def RunCmdOnDevice(self, args, cwd=None, quiet=False):
    """Runs |args| on the device and returns its output.

    Args:
      args: Sequence of strings forming the command to run.
      cwd: Working directory for the local ssh/sh subprocess, or None.
      quiet: If True, suppresses debug logging of the command and its output.

    Returns:
      A (stdout, stderr) tuple, with the ssh "Permanently added ... known
      hosts" warning stripped from stderr.
    """
    stdout, stderr = GetAllCmdOutput(
        self.FormSSHCommandLine(args), cwd, quiet=quiet)
    # The initial login will add the host to the hosts file but will also print
    # a warning to stderr that we need to remove.
    stderr = self._RemoveSSHWarnings(stderr)
    return stdout, stderr

  def TryLogin(self):
    """Checks that commands can be run on the remote device as root.

    Must not be called on a local interface.

    Raises:
      KeylessLoginRequiredException: if ssh reports that public-key
        authentication was denied.
      LoginException: for any other ssh error output, or if the remote $USER
        is not root.
    """
    logging.debug('TryLogin()')
    assert not self.local
    stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
    if stderr != '':
      # Translate the well-known ssh failures into actionable messages.
      if 'Host key verification failed' in stderr:
        raise LoginException(('%s host key verification failed. ' +
                              'SSH to it manually to fix connectivity.') %
                             self._hostname)
      if 'Operation timed out' in stderr:
        raise LoginException('Timed out while logging into %s' % self._hostname)
      if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
        raise LoginException('Permissions for %s are too open. To fix this,\n'
                             'chmod 600 %s' % (self._ssh_identity,
                                               self._ssh_identity))
      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
        raise KeylessLoginRequiredException(
            'Need to set up ssh auth for %s' % self._hostname)
      raise LoginException('While logging into %s, got %s' % (
          self._hostname, stderr))
    if stdout != 'root\n':
      raise LoginException(
          'Logged into %s, expected $USER=root, but got %s.' % (
              self._hostname, stdout))

  def FileExistsOnDevice(self, file_name):
    """Returns True if |file_name| exists on the device.

    Raises:
      OSError: if the remote check wrote anything to stderr.
    """
    if self.local:
      return os.path.exists(file_name)

    stdout, stderr = self.RunCmdOnDevice([
        'if', 'test', '-e', file_name, ';',
        'then', 'echo', '1', ';',
        'fi'
        ], quiet=True)
    if stderr != '':
      if "Connection timed out" in stderr:
        raise OSError('Machine wasn\'t responding to ssh: %s' %
                      stderr)
      raise OSError('Unexpected error: %s' % stderr)
    exists = stdout == '1\n'
    logging.debug("FileExistsOnDevice(<text>, %s)->%s", file_name, exists)
    return exists

  def PushFile(self, filename, remote_filename):
    """Copies the local path |filename| to |remote_filename| on the device.

    Uses `cp -r` in local mode and `scp -r` otherwise.

    Raises:
      OSError: if the copy command wrote anything to stderr (other than the
        ssh known-hosts warning).
    """
    if self.local:
      args = ['cp', '-r', filename, remote_filename]
      _, stderr = GetAllCmdOutput(args, quiet=True)
      if stderr != '':
        raise OSError('No such file or directory %s' % stderr)
      return

    args = ['scp', '-r'] + self._ssh_args
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])

    args.extend([os.path.abspath(filename),
                 'root@%s:%s' % (self._hostname, remote_filename)])

    _, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def PushContents(self, text, remote_filename):
    """Writes |text| to |remote_filename| on the device via a temporary file."""
    logging.debug("PushContents(<text>, %s)", remote_filename)
    with tempfile.NamedTemporaryFile() as f:
      f.write(text)
      # Make sure the data is on disk before the copy command reads the file.
      f.flush()
      self.PushFile(f.name, remote_filename)

  def GetFile(self, filename, destfile=None):
    """Copies the file |filename| from the device to |destfile| locally.

    In local mode the file is only copied when |destfile| is given and differs
    from |filename|; otherwise nothing is done.

    Args:
      filename: The name of the source file on the device.
      destfile: The name of the file to copy to, and if it is not specified
        then it is the basename of the source file.

    Raises:
      OSError: if scp wrote anything to stderr (other than the ssh known-hosts
        warning).
    """
    logging.debug("GetFile(%s, %s)", filename, destfile)
    if self.local:
      if destfile is not None and destfile != filename:
        shutil.copyfile(filename, destfile)
      return

    if destfile is None:
      destfile = os.path.basename(filename)
    args = ['scp'] + self._ssh_args
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])

    args.extend(['root@%s:%s' % (self._hostname, filename),
                 os.path.abspath(destfile)])
    _, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def GetFileContents(self, filename):
    """Get the contents of a file on the device.

    Args:
      filename: The name of the file on the device.

    Returns:
      A string containing the contents of the file.
    """
    # TODO: handle the self.local case
    assert not self.local
    # The temporary file only provides a destination path for GetFile(); the
    # with-block guarantees it is closed (and thereby deleted) afterwards.
    with tempfile.NamedTemporaryFile() as t:
      self.GetFile(filename, t.name)
      with open(t.name, 'r') as f2:
        res = f2.read()
    logging.debug("GetFileContents(%s)->%s", filename, res)
    return res

  def ListProcesses(self):
    """Returns (pid, cmd, ppid, state) of all processes on the device."""
    stdout, stderr = self.RunCmdOnDevice([
        '/bin/ps', '--no-headers',
        '-A',
        '-o', 'pid,ppid,args:4096,state'], quiet=True)
    assert stderr == '', stderr
    # Columns are pid, ppid, args, state. The args group is greedy, so the
    # final group captures only the trailing state field.
    line_re = re.compile(r'^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', re.DOTALL)
    procs = []
    for line in stdout.split('\n'): # pylint: disable=E1103
      if line == '':
        continue
      m = line_re.match(line)
      assert m
      procs.append((int(m.group(1)), m.group(3).rstrip(),
                    int(m.group(2)), m.group(4)))
    logging.debug("ListProcesses(<predicate>)->[%i processes]", len(procs))
    return procs

  def _GetSessionManagerPid(self, procs):
    """Returns the pid of the session_manager process, given the list of
    processes, or None if there is no such process."""
    for pid, process, _, _ in procs:
      argv = process.split()
      if argv and os.path.basename(argv[0]) == 'session_manager':
        return pid
    return None

  def GetChromeProcess(self):
    """Locates the main chrome browser process.

    Chrome on cros is usually in /opt/google/chrome, but could be in
    /usr/local/ for developer workflows - debug chrome is too large to fit on
    rootfs.

    Chrome spawns multiple processes for renderers. pids wrap around after they
    are exhausted so looking for the smallest pid is not always correct. We
    locate the session_manager's pid, and look for the chrome process that's an
    immediate child. This is the main browser process.

    Returns:
      A dict with keys 'pid', 'path' (the matching _CHROME_PATHS entry) and
      'args' (the full command line), or None if no such process is found.
    """
    procs = self.ListProcesses()
    session_manager_pid = self._GetSessionManagerPid(procs)
    if not session_manager_pid:
      return None

    # Find the chrome process that is the child of the session_manager.
    for pid, process, ppid, _ in procs:
      if ppid != session_manager_pid:
        continue
      for path in _CHROME_PATHS:
        if process.startswith(path):
          return {'pid': pid, 'path': path, 'args': process}
    return None

  def GetChromePid(self):
    """Returns pid of main chrome browser process, or None if not found."""
    result = self.GetChromeProcess()
    if result and 'pid' in result:
      return result['pid']
    return None

  def RmRF(self, filename):
    """Runs `rm -rf |filename|` on the device; output is ignored."""
    logging.debug("rm -rf %s", filename)
    self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)

  def Chown(self, filename):
    """Recursively changes the owner of |filename| to chronos:chronos."""
    self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])

  def KillAllMatching(self, predicate):
    """Sends SIGKILL to every process whose command line satisfies |predicate|.

    Args:
      predicate: Callable taking a process command line (string) and
        returning a true value for processes that should be killed.

    Returns:
      The number of processes a kill was issued for.
    """
    kills = ['kill', '-KILL']
    for pid, cmd, _, _ in self.ListProcesses():
      if predicate(cmd):
        logging.info('Killing %s, pid %d', cmd, pid)
        # Command-line arguments must be strings, not ints.
        kills.append(str(pid))
    num_killed = len(kills) - 2
    logging.debug("KillAllMatching(<predicate>)->%i", num_killed)
    if num_killed > 0:
      self.RunCmdOnDevice(kills, quiet=True)
    return num_killed

  def IsServiceRunning(self, service_name):
    """Returns True if `status |service_name|` reports a running process."""
    stdout, stderr = self.RunCmdOnDevice([
        'status', service_name], quiet=True)
    assert stderr == '', stderr
    running = 'running, process' in stdout
    logging.debug("IsServiceRunning(%s)->%s", service_name, running)
    return running

  def GetRemotePort(self):
    """Returns a port number above every port known to be taken.

    The result is one more than the highest local port listed by
    `netstat -ant` or previously returned by this method; it is remembered so
    later calls do not hand out the same port again.
    """
    stdout, _ = self.RunCmdOnDevice(['netstat', '-ant'])
    ports_in_use = []

    # The first two lines of netstat output are headers.
    for line in stdout.split('\n')[2:]:
      if not line:
        continue
      # Fourth column is the local address; the port follows the last ':'.
      address_in_use = line.split()[3]
      port_in_use = address_in_use.split(':')[-1]
      ports_in_use.append(int(port_in_use))

    ports_in_use.extend(self._reserved_ports)

    new_port = max(ports_in_use) + 1
    self._reserved_ports.append(new_port)

    return new_port

  def IsHTTPServerRunningOnPort(self, port):
    """Returns False if wget to localhost:|port| on the device is refused."""
    _, stderr = self.RunCmdOnDevice(
        ['wget', 'localhost:%i' % (port), '-T1', '-t1'])
    return 'Connection refused' not in stderr

  def FilesystemMountedAt(self, path):
    """Returns the filesystem mounted at |path|, or None if it can't be read
    from the df output."""
    df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
    df_ary = df_out.split('\n')
    # 3 lines for title, mount info, and empty line.
    if len(df_ary) == 3:
      line_ary = df_ary[1].split()
      if line_ary:
        return line_ary[0]
    return None

  def CryptohomePath(self, user):
    """Returns the cryptohome mount point for |user|.

    Raises:
      OSError: if cryptohome-path wrote anything to stderr.
    """
    stdout, stderr = self.RunCmdOnDevice(
        ['cryptohome-path', 'user', "'%s'" % user])
    if stderr != '':
      raise OSError('cryptohome-path failed: %s' % stderr)
    return stdout.rstrip()

  def IsCryptohomeMounted(self, username, is_guest):
    """Returns True iff |username|'s cryptohome is mounted.

    Args:
      username: User whose cryptohome path is checked.
      is_guest: If True the mount is expected to start with 'guestfs',
        otherwise with '/home/.shadow/'.
    """
    profile_path = self.CryptohomePath(username)
    mount = self.FilesystemMountedAt(profile_path)
    mount_prefix = 'guestfs' if is_guest else '/home/.shadow/'
    # bool() so callers get False rather than None when nothing is mounted.
    return bool(mount and mount.startswith(mount_prefix))

  def TakeScreenShot(self, screenshot_prefix):
    """Takes a screenshot, useful for debugging failures.

    The image is stored under the first unused name of the form
    /var/log/screenshots/<screenshot_prefix>-<i>.png for i in 0..24; if all of
    them exist a warning is logged and no screenshot is taken.
    """
    # TODO(achuith): Find a better location for screenshots. Cros autotests
    # upload everything in /var/log so use /var/log/screenshots for now.
    SCREENSHOT_DIR = '/var/log/screenshots/'
    SCREENSHOT_EXT = '.png'

    self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
    for i in range(25):
      screenshot_file = ('%s%s-%d%s' %
                         (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
      if not self.FileExistsOnDevice(screenshot_file):
        self.RunCmdOnDevice([
            'DISPLAY=:0.0 XAUTHORITY=/home/chronos/.Xauthority '
            '/usr/local/bin/import',
            '-window root',
            '-depth 8',
            screenshot_file])
        return
    logging.warning('screenshot directory full.')

  def RestartUI(self, clear_enterprise_policy):
    """Starts or restarts the ui service, which logs the user out.

    Args:
      clear_enterprise_policy: If True, the ui is stopped first and
        /var/lib/whitelist/* and /home/chronos/Local State are removed.
    """
    logging.info('(Re)starting the ui (logs the user out)')
    if clear_enterprise_policy:
      self.RunCmdOnDevice(['stop', 'ui'])
      self.RmRF('/var/lib/whitelist/*')
      # The backslash is sent as-is so the shell that runs the command treats
      # "Local State" as a single path.
      self.RmRF('/home/chronos/Local\\ State')

    if self.IsServiceRunning('ui'):
      self.RunCmdOnDevice(['restart', 'ui'])
    else:
      self.RunCmdOnDevice(['start', 'ui'])

  def CloseConnection(self):
    """Asks the ssh control master to exit. Does nothing in local mode."""
    if not self.local:
      with open(os.devnull, 'w') as devnull:
        subprocess.call(self.FormSSHCommandLine(['-O', 'exit', self._hostname]),
                        stdout=devnull, stderr=devnull)
455