• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2013 The Chromium Authors. All rights reserved.
2# Use of this source code is governed by a BSD-style license that can be
3# found in the LICENSE file.
4"""A wrapper around ssh for common operations on a CrOS-based device"""
5import logging
6import os
7import re
8import shutil
9import stat
10import subprocess
11import tempfile
12
# The main Chrome binary normally lives in /opt/google/chrome, but some
# developer workflows launch it from somewhere under /usr/local/ instead, so a
# process command line is accepted if it starts with either location.
_CHROME_PROCESS_REGEX = list(map(re.compile, (
    r'^/opt/google/chrome/chrome ',     # Default install location.
    r'^/usr/local/?.*/chrome/chrome ',  # Developer workflow location.
)))
18
19
def RunCmd(args, cwd=None, quiet=False):
  """Runs a program to completion and returns its exit status.

  The child's stdin, stdout and stderr are all attached to os.devnull, so its
  output is discarded and any attempt to read input sees end-of-file.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, the command line is not logged.

  Returns:
    Return code from the command execution.

  Raises:
    OSError: If the program cannot be started (e.g. it is not installed).
  """
  if not quiet:
    # A string |args| is already the whole command; joining it would put a
    # space between every character.
    command = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', command, cwd or '')
  # Open for reading AND writing: the same handle serves as stdin, and a
  # write-only stdin makes reads in the child fail with EBADF instead of EOF.
  with open(os.devnull, 'r+') as devnull:
    p = subprocess.Popen(args=args,
                         cwd=cwd,
                         stdout=devnull,
                         stderr=devnull,
                         stdin=devnull,
                         shell=False)
    return p.wait()
42
43
def GetAllCmdOutput(args, cwd=None, quiet=False):
  """Runs a program to completion and returns everything it printed.

  Args:
    args: A string or a sequence of program arguments. The program to execute is
      the string or the first item in the args sequence.
    cwd: If not None, the subprocess's current directory will be changed to
      |cwd| before it's executed.
    quiet: If True, neither the command line nor its output is logged.

  Returns:
    A (stdout, stderr) tuple with the command's complete output on each stream.
    Unless |quiet| is set, both are also logged at debug level.

  Raises:
    OSError: If the program cannot be started (e.g. it is not installed).
  """
  if not quiet:
    # A string |args| is already the whole command; joining it would put a
    # space between every character.
    command = args if isinstance(args, str) else ' '.join(args)
    logging.debug('%s %s', command, cwd or '')
  # stdin must be readable: with a write-only descriptor a child that reads
  # input gets EBADF and prints an error to stderr, which callers then treat
  # as a command failure.
  with open(os.devnull, 'r') as devnull:
    p = subprocess.Popen(args=args,
                         cwd=cwd,
                         stdout=subprocess.PIPE,
                         stderr=subprocess.PIPE,
                         stdin=devnull)
    stdout, stderr = p.communicate()
  if not quiet:
    logging.debug(' > stdout=[%s], stderr=[%s]', stdout, stderr)
  return stdout, stderr
69
70
def HasSSH():
  """Reports whether both the `ssh` and `scp` programs can be launched.

  Each tool is started once with no arguments; only a failure to execute it
  (OSError) matters, not its exit status.

  Returns:
    True if both programs ran, False otherwise.
  """
  try:
    for tool in ('ssh', 'scp'):
      RunCmd([tool], quiet=True)
  except OSError:
    logging.debug("HasSSH()->False")
    return False
  logging.debug("HasSSH()->True")
  return True
80
81
class LoginException(Exception):
  """Raised when logging into a remote device over ssh fails."""
84
85
class KeylessLoginRequiredException(LoginException):
  """Raised when the device rejects public-key ssh login; ssh auth for the
  host has to be set up first."""
88
89
class DNSFailureException(LoginException):
  """Raised when ssh cannot resolve the device's hostname."""
92
93
class CrOSInterface(object):
  """Runs commands and transfers files on a ChromeOS device.

  If no hostname is given, the target is the local machine: commands run via
  `sh -c` and files are copied locally. Otherwise commands run over ssh as
  root. They share one master connection whose control socket is created in
  __init__ and shut down by CloseConnection(), which is also called when a
  `with` block exits.
  """

  def __init__(self, hostname=None, ssh_port=None, ssh_identity=None):
    """Initializes the interface and, for a remote host, starts the ssh master.

    Args:
      hostname: Device hostname or address; a false value selects the local
        machine.
      ssh_port: Port for ssh/scp. If None, no port option is passed, so the
        ssh/scp default applies.
      ssh_identity: Optional path to a private key. '~' is expanded and the
        file is made owner-read-only.
    """
    self._hostname = hostname
    self._ssh_port = ssh_port

    # List of ports generated from GetRemotePort() that may not be in use yet.
    self._reserved_ports = []

    if self.local:
      return

    self._ssh_identity = None
    self._ssh_args = ['-o ConnectTimeout=5', '-o StrictHostKeyChecking=no',
                      '-o KbdInteractiveAuthentication=no',
                      '-o PreferredAuthentications=publickey',
                      '-o UserKnownHostsFile=/dev/null', '-o ControlMaster=no']

    if ssh_identity:
      self._ssh_identity = os.path.abspath(os.path.expanduser(ssh_identity))
      # ssh refuses keys whose permissions are too open (TryLogin reports
      # that case), so restrict the key to owner-read.
      os.chmod(self._ssh_identity, stat.S_IREAD)

    # Establish master SSH connection using ControlPersist.
    # Since only one test will be run on a remote host at a time,
    # the control socket filename can be telemetry@hostname.
    self._ssh_control_file = '/tmp/telemetry@' + hostname
    with open(os.devnull, 'w') as devnull:
      subprocess.call(
          self.FormSSHCommandLine(['-M', '-o ControlPersist=yes']),
          stdin=devnull,
          stdout=devnull,
          stderr=devnull)

  def __enter__(self):
    return self

  def __exit__(self, *args):
    """Closes the master ssh connection when leaving a `with` block."""
    self.CloseConnection()

  @property
  def local(self):
    """True when no hostname was given, i.e. the target is this machine."""
    return not self._hostname

  @property
  def hostname(self):
    return self._hostname

  @property
  def ssh_port(self):
    return self._ssh_port

  def FormSSHCommandLine(self, args, extra_ssh_args=None):
    """Constructs a subprocess-suitable command line for `ssh'.

    Args:
      args: Command words to run on the device.
      extra_ssh_args: Optional extra ssh options, placed before the
        destination.

    Returns:
      A list of strings. For a local target this is
      ['sh', '-c', <args joined by spaces>]. Otherwise it is an ssh
      invocation as root that goes through the master control socket.
    """
    if self.local:
      # We run the command through the shell locally for consistency with
      # how commands are run through SSH (crbug.com/239161). This work
      # around will be unnecessary once we implement a persistent SSH
      # connection to run remote commands (crbug.com/239607).
      return ['sh', '-c', " ".join(args)]

    full_args = ['ssh', '-o ForwardX11=no', '-o ForwardX11Trusted=no', '-n',
                 '-S', self._ssh_control_file] + self._ssh_args
    if self._ssh_identity is not None:
      full_args.extend(['-i', self._ssh_identity])
    if extra_ssh_args:
      full_args.extend(extra_ssh_args)
    full_args.append('root@%s' % self._hostname)
    # ssh_port defaults to None; '-p%d' % None would raise, so only pass a
    # port when one was given and otherwise let ssh use its default.
    if self._ssh_port is not None:
      full_args.append('-p%d' % self._ssh_port)
    full_args.extend(args)
    return full_args

  def _FormSCPCommandLine(self, src, dst, extra_scp_args=None):
    """Constructs a subprocess-suitable command line for `scp'.

    Note: this function is not designed to work with IPv6 addresses, which need
    to have their addresses enclosed in brackets and a '-6' flag supplied
    in order to be properly parsed by `scp'.
    """
    assert not self.local, "Cannot use SCP on local target."

    args = ['scp']
    # Without a configured port, omit -P rather than passing '-P None'.
    if self._ssh_port is not None:
      args.extend(['-P', str(self._ssh_port)])
    args.extend(self._ssh_args)
    if self._ssh_identity:
      args.extend(['-i', self._ssh_identity])
    if extra_scp_args:
      args.extend(extra_scp_args)
    args += [src, dst]
    return args

  def _FormSCPToRemote(self,
                       source,
                       remote_dest,
                       extra_scp_args=None,
                       user='root'):
    """Builds an scp command copying local |source| to device |remote_dest|."""
    return self._FormSCPCommandLine(source,
                                    '%s@%s:%s' % (user, self._hostname,
                                                  remote_dest),
                                    extra_scp_args=extra_scp_args)

  def _FormSCPFromRemote(self,
                         remote_source,
                         dest,
                         extra_scp_args=None,
                         user='root'):
    """Builds an scp command copying device |remote_source| to local |dest|."""
    return self._FormSCPCommandLine('%s@%s:%s' % (user, self._hostname,
                                                  remote_source),
                                    dest,
                                    extra_scp_args=extra_scp_args)

  def _RemoveSSHWarnings(self, toClean):
    """Removes specific ssh warning lines from a string.

    Args:
      toClean: A string that may be containing multiple lines.

    Returns:
      A copy of toClean with all the Warning lines removed.
    """
    # Remove the Warning about connecting to a new host for the first time.
    return re.sub(
        r'Warning: Permanently added [^\n]* to the list of known hosts.\s\n',
        '', toClean)

  def RunCmdOnDevice(self, args, cwd=None, quiet=False):
    """Runs |args| on the device (locally via `sh -c` for a local target).

    Returns:
      A (stdout, stderr) tuple. ssh's "Permanently added ... to the list of
      known hosts" warning is removed from stderr.
    """
    stdout, stderr = GetAllCmdOutput(
        self.FormSSHCommandLine(args),
        cwd,
        quiet=quiet)
    # The initial login will add the host to the hosts file but will also print
    # a warning to stderr that we need to remove.
    stderr = self._RemoveSSHWarnings(stderr)
    return stdout, stderr

  def TryLogin(self):
    """Checks that ssh login as root works.

    Raises:
      KeylessLoginRequiredException: If public-key authentication is refused.
      DNSFailureException: If the hostname cannot be resolved.
      LoginException: For any other stderr output, or if $USER is not root.
    """
    logging.debug('TryLogin()')
    assert not self.local
    stdout, stderr = self.RunCmdOnDevice(['echo', '$USER'], quiet=True)
    if stderr != '':
      if 'Host key verification failed' in stderr:
        raise LoginException(('%s host key verification failed. ' +
                              'SSH to it manually to fix connectivity.') %
                             self._hostname)
      if 'Operation timed out' in stderr:
        raise LoginException('Timed out while logging into %s' % self._hostname)
      if 'UNPROTECTED PRIVATE KEY FILE!' in stderr:
        raise LoginException('Permissions for %s are too open. To fix this,\n'
                             'chmod 600 %s' % (self._ssh_identity,
                                               self._ssh_identity))
      if 'Permission denied (publickey,keyboard-interactive)' in stderr:
        raise KeylessLoginRequiredException('Need to set up ssh auth for %s' %
                                            self._hostname)
      if 'Could not resolve hostname' in stderr:
        raise DNSFailureException('Unable to resolve the hostname for: %s' %
                                  self._hostname)
      raise LoginException('While logging into %s, got %s' % (self._hostname,
                                                              stderr))
    if stdout != 'root\n':
      raise LoginException('Logged into %s, expected $USER=root, but got %s.' %
                           (self._hostname, stdout))

  def FileExistsOnDevice(self, file_name):
    """Returns True if |file_name| exists on the device.

    Raises:
      OSError: If the remote check writes anything to stderr.
    """
    if self.local:
      return os.path.exists(file_name)

    stdout, stderr = self.RunCmdOnDevice(
        [
            'if', 'test', '-e', file_name, ';', 'then', 'echo', '1', ';', 'fi'
        ],
        quiet=True)
    if stderr != '':
      if "Connection timed out" in stderr:
        raise OSError('Machine wasn\'t responding to ssh: %s' % stderr)
      raise OSError('Unexpected error: %s' % stderr)
    exists = stdout == '1\n'
    logging.debug("FileExistsOnDevice(<text>, %s)->%s", file_name, exists)
    return exists

  def PushFile(self, filename, remote_filename):
    """Copies the local file or directory |filename| to |remote_filename|.

    Raises:
      OSError: If the copy writes anything to stderr (apart from ssh's
        known-hosts warning).
    """
    if self.local:
      args = ['cp', '-r', filename, remote_filename]
      _, stderr = GetAllCmdOutput(args, quiet=True)
      if stderr != '':
        raise OSError('No such file or directory %s' % stderr)
      return

    args = self._FormSCPToRemote(
        os.path.abspath(filename),
        remote_filename,
        extra_scp_args=['-r'])

    _, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def PushContents(self, text, remote_filename):
    """Writes |text| to |remote_filename| on the device via a temp file."""
    logging.debug("PushContents(<text>, %s)", remote_filename)
    with tempfile.NamedTemporaryFile() as f:
      f.write(text)
      f.flush()
      self.PushFile(f.name, remote_filename)

  def GetFile(self, filename, destfile=None):
    """Copies the device file |filename| to the local path |destfile|.

    Args:
      filename: The name of the source file on the device.
      destfile: The local path to copy to. For a remote device it defaults to
        the basename of |filename| in the current directory.

    Raises:
      OSError: If scp reports an error. For a local target, also raised when
        |destfile| is None or equal to |filename|.
    """
    logging.debug("GetFile(%s, %s)", filename, destfile)
    if self.local:
      if destfile is not None and destfile != filename:
        shutil.copyfile(filename, destfile)
        return
      else:
        raise OSError('No such file or directory %s' % filename)

    if destfile is None:
      destfile = os.path.basename(filename)
    args = self._FormSCPFromRemote(filename, os.path.abspath(destfile))

    _, stderr = GetAllCmdOutput(args, quiet=True)
    stderr = self._RemoveSSHWarnings(stderr)
    if stderr != '':
      raise OSError('No such file or directory %s' % stderr)

  def GetFileContents(self, filename):
    """Get the contents of a file on the device.

    Args:
      filename: The name of the file on the device.

    Returns:
      A string containing the contents of the file.
    """
    with tempfile.NamedTemporaryFile() as t:
      self.GetFile(filename, t.name)
      with open(t.name, 'r') as f2:
        res = f2.read()
        logging.debug("GetFileContents(%s)->%s", filename, res)
        return res

  def ListProcesses(self):
    """Returns (pid, cmd, ppid, state) of all processes on the device."""
    # 'args:4096' requests a 4096-column command field; presumably so long
    # Chrome command lines are not truncated.
    stdout, stderr = self.RunCmdOnDevice(
        [
            '/bin/ps', '--no-headers', '-A', '-o', 'pid,ppid,args:4096,state'
        ],
        quiet=True)
    assert stderr == '', stderr
    procs = []
    for l in stdout.split('\n'):
      if l == '':
        continue
      # The greedy args group stops at the last whitespace; state is the final
      # token.
      m = re.match(r'^\s*(\d+)\s+(\d+)\s+(.+)\s+(.+)', l, re.DOTALL)
      assert m
      procs.append((int(m.group(1)), m.group(3).rstrip(), int(m.group(2)),
                    m.group(4)))
    logging.debug("ListProcesses(<predicate>)->[%i processes]", len(procs))
    return procs

  def _GetSessionManagerPid(self, procs):
    """Returns the pid of the session_manager process, given the list of
    processes."""
    for pid, process, _, _ in procs:
      argv = process.split()
      if argv and os.path.basename(argv[0]) == 'session_manager':
        return pid
    return None

  def GetChromeProcess(self):
    """Locates the the main chrome browser process.

    Chrome on cros is usually in /opt/google/chrome, but could be in
    /usr/local/ for developer workflows - debug chrome is too large to fit on
    rootfs.

    Chrome spawns multiple processes for renderers. pids wrap around after they
    are exhausted so looking for the smallest pid is not always correct. We
    locate the session_manager's pid, and look for the chrome process that's an
    immediate child. This is the main browser process.

    Returns:
      A dict with 'pid', 'path' (the matched binary prefix) and 'args' (the
      full command line), or None if no such process is found.
    """
    procs = self.ListProcesses()
    session_manager_pid = self._GetSessionManagerPid(procs)
    if not session_manager_pid:
      return None

    # Find the chrome process that is the child of the session_manager.
    for pid, process, ppid, _ in procs:
      if ppid != session_manager_pid:
        continue
      for regex in _CHROME_PROCESS_REGEX:
        path_match = regex.match(process)
        if path_match is not None:
          return {'pid': pid, 'path': path_match.group(), 'args': process}
    return None

  def GetChromePid(self):
    """Returns pid of main chrome browser process."""
    result = self.GetChromeProcess()
    if result and 'pid' in result:
      return result['pid']
    return None

  def RmRF(self, filename):
    """Runs `rm -rf |filename|` on the device, ignoring the result.

    The command goes through a shell (see FormSSHCommandLine), so shell
    patterns such as '*' in |filename| are expanded there.
    """
    logging.debug("rm -rf %s", filename)
    self.RunCmdOnDevice(['rm', '-rf', filename], quiet=True)

  def Chown(self, filename):
    """Recursively changes the owner of |filename| to chronos:chronos."""
    self.RunCmdOnDevice(['chown', '-R', 'chronos:chronos', filename])

  def KillAllMatching(self, predicate):
    """Sends SIGKILL to every device process whose command line matches.

    Args:
      predicate: Callable taking a process command line; returns True for
        processes that should be killed.

    Returns:
      The number of matching processes (all sent in one `kill -KILL`).
    """
    kills = ['kill', '-KILL']
    for pid, cmd, _, _ in self.ListProcesses():
      if predicate(cmd):
        # The message has two placeholders, so hand both values to logging.
        logging.info('Killing %s, pid %d', cmd, pid)
        # Command words must be strings: they are joined with " ".join for a
        # local target and passed to subprocess for ssh.
        kills.append(str(pid))
    num_killed = len(kills) - 2
    logging.debug("KillAllMatching(<predicate>)->%i", num_killed)
    if num_killed:
      self.RunCmdOnDevice(kills, quiet=True)
    return num_killed

  def IsServiceRunning(self, service_name):
    """Returns True if `status |service_name|` reports 'running, process'."""
    stdout, stderr = self.RunCmdOnDevice(['status', service_name], quiet=True)
    assert stderr == '', stderr
    running = 'running, process' in stdout
    logging.debug("IsServiceRunning(%s)->%s", service_name, running)
    return running

  def GetRemotePort(self):
    """Returns a device port number that is probably free.

    The port is one more than the highest port that `netstat -ant` reports in
    use or that an earlier call already handed out. Handed-out ports are
    remembered because the caller may not have bound them yet.
    """
    netstat_lines = self.RunCmdOnDevice(['netstat', '-ant'])[0].split('\n')
    ports_in_use = list(self._reserved_ports)

    # The first two lines of netstat output are headers.
    for line in netstat_lines[2:]:
      if not line:
        continue
      # Column 3 is expected to be the local address, ending in ':<port>'.
      address_in_use = line.split()[3]
      ports_in_use.append(int(address_in_use.split(':')[-1]))

    new_port = max(ports_in_use) + 1
    self._reserved_ports.append(new_port)

    return new_port

  def IsHTTPServerRunningOnPort(self, port):
    """Returns False only if wget to localhost:|port| gets 'Connection refused'.

    Any other wget outcome, including a timeout, counts as running.
    """
    wget_output = self.RunCmdOnDevice(['wget', 'localhost:%i' % (port), '-T1',
                                       '-t1'])

    if 'Connection refused' in wget_output[1]:
      return False

    return True

  def FilesystemMountedAt(self, path):
    """Returns the filesystem mounted at |path|.

    This is the first column of the single data line printed by `df`, or
    None if df's output does not have exactly that shape.
    """
    df_out, _ = self.RunCmdOnDevice(['/bin/df', path])
    df_ary = df_out.split('\n')
    # 3 lines for title, mount info, and empty line.
    if len(df_ary) == 3:
      line_ary = df_ary[1].split()
      if line_ary:
        return line_ary[0]
    return None

  def CryptohomePath(self, user):
    """Returns the cryptohome mount point for |user|.

    Raises:
      OSError: If cryptohome-path writes anything to stderr.
    """
    # |user| is single-quoted for the device shell.
    # NOTE(review): a user name containing a single quote breaks this quoting.
    stdout, stderr = self.RunCmdOnDevice(['cryptohome-path', 'user', "'%s'" %
                                          user])
    if stderr != '':
      raise OSError('cryptohome-path failed: %s' % stderr)
    return stdout.rstrip()

  def IsCryptohomeMounted(self, username, is_guest):
    """Returns True iff |username|'s cryptohome is mounted.

    A guest session is expected to be backed by 'guestfs'. A regular user's
    cryptohome is expected to come from under '/home/.shadow/'.
    """
    profile_path = self.CryptohomePath(username)
    mount = self.FilesystemMountedAt(profile_path)
    mount_prefix = 'guestfs' if is_guest else '/home/.shadow/'
    # bool() so that "nothing mounted" yields False rather than None.
    return bool(mount and mount.startswith(mount_prefix))

  def TakeScreenshot(self, file_path):
    """Saves a screenshot to |file_path| on the device.

    Returns:
      True if the screenshot script printed nothing to stdout or stderr.
    """
    stdout, stderr = self.RunCmdOnDevice(
        ['/usr/local/autotest/bin/screenshot.py', file_path])
    return stdout == '' and stderr == ''

  def TakeScreenshotWithPrefix(self, screenshot_prefix):
    """Takes a screenshot, useful for debugging failures."""
    # TODO(achuith): Find a better location for screenshots. Cros autotests
    # upload everything in /var/log so use /var/log/screenshots for now.
    SCREENSHOT_DIR = '/var/log/screenshots/'
    SCREENSHOT_EXT = '.png'

    self.RunCmdOnDevice(['mkdir', '-p', SCREENSHOT_DIR])
    # Large number of screenshots can increase hardware lab bandwidth
    # dramatically, so keep this number low. crbug.com/524814.
    for i in range(2):
      screenshot_file = ('%s%s-%d%s' %
                         (SCREENSHOT_DIR, screenshot_prefix, i, SCREENSHOT_EXT))
      if not self.FileExistsOnDevice(screenshot_file):
        return self.TakeScreenshot(screenshot_file)
    logging.warning('screenshot directory full.')
    return False

  def GetArchName(self):
    """Returns the raw (unstripped) stdout of `uname -m` on the device."""
    return self.RunCmdOnDevice(['uname', '-m'])[0]

  def IsRunningOnVM(self):
    """True unless `crossystem inside_vm` prints exactly '0'."""
    return self.RunCmdOnDevice(['crossystem', 'inside_vm'])[0] != '0'

  def LsbReleaseValue(self, key, default):
    """Returns the value of |key| in /etc/lsb-release, or |default|.

    /etc/lsb-release is a file with key=value pairs.
    """
    lines = self.GetFileContents('/etc/lsb-release').split('\n')
    for l in lines:
      m = re.match(r'([^=]*)=(.*)', l)
      if m and m.group(1) == key:
        return m.group(2)
    return default

  def GetDeviceTypeName(self):
    """DEVICETYPE in /etc/lsb-release is CHROMEBOOK, CHROMEBIT, etc."""
    return self.LsbReleaseValue(key='DEVICETYPE', default='CHROMEBOOK')

  def RestartUI(self, clear_enterprise_policy):
    """Restarts (or starts) the 'ui' job on the device.

    Args:
      clear_enterprise_policy: If True, stop ui first and delete
        /var/lib/whitelist/* and /home/chronos/Local State.
    """
    logging.info('(Re)starting the ui (logs the user out)')
    if clear_enterprise_policy:
      self.RunCmdOnDevice(['stop', 'ui'])
      self.RmRF('/var/lib/whitelist/*')
      self.RmRF(r'/home/chronos/Local\ State')

    if self.IsServiceRunning('ui'):
      self.RunCmdOnDevice(['restart', 'ui'])
    else:
      self.RunCmdOnDevice(['start', 'ui'])

  def CloseConnection(self):
    """Asks the master ssh connection to exit (`ssh -O exit`); no-op locally."""
    if not self.local:
      with open(os.devnull, 'w') as devnull:
        subprocess.call(
            self.FormSSHCommandLine(['-O', 'exit', self._hostname]),
            stdout=devnull,
            stderr=devnull)
540