1# Lint as: python2, python3 2# Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>, 3# Benjamin Poirier, Ryan Stutsman 4# Released under the GPL v2 5""" 6Miscellaneous small functions. 7 8DO NOT import this file directly - it is mixed in by server/utils.py, 9import that instead 10""" 11 12from __future__ import absolute_import 13from __future__ import division 14from __future__ import print_function 15 16import atexit, os, re, shutil, textwrap, sys, tempfile, types 17import six 18 19from autotest_lib.client.common_lib import barrier, utils 20from autotest_lib.server import subcommand 21 22 23# A dictionary of pid and a list of tmpdirs for that pid 24__tmp_dirs = {} 25 26 27def scp_remote_escape(filename): 28 """ 29 Escape special characters from a filename so that it can be passed 30 to scp (within double quotes) as a remote file. 31 32 Bis-quoting has to be used with scp for remote files, "bis-quoting" 33 as in quoting x 2 34 scp does not support a newline in the filename 35 36 Args: 37 filename: the filename string to escape. 38 39 Returns: 40 The escaped filename string. The required englobing double 41 quotes are NOT added and so should be added at some point by 42 the caller. 43 """ 44 escape_chars= r' !"$&' "'" r'()*,:;<=>?[\]^`{|}' 45 46 new_name= [] 47 for char in filename: 48 if char in escape_chars: 49 new_name.append("\\%s" % (char,)) 50 else: 51 new_name.append(char) 52 53 return utils.sh_escape("".join(new_name)) 54 55 56def get(location, local_copy = False): 57 """Get a file or directory to a local temporary directory. 58 59 Args: 60 location: the source of the material to get. This source may 61 be one of: 62 * a local file or directory 63 * a URL (http or ftp) 64 * a python file-like object 65 66 Returns: 67 The location of the file or directory where the requested 68 content was saved. This will be contained in a temporary 69 directory on the local host. If the material to get was a 70 directory, the location will contain a trailing '/' 71 """ 72 tmpdir = get_tmp_dir() 73 74 # location is a file-like object 75 if hasattr(location, "read"): 76 tmpfile = os.path.join(tmpdir, "file") 77 tmpfileobj = open(tmpfile, 'w') 78 shutil.copyfileobj(location, tmpfileobj) 79 tmpfileobj.close() 80 return tmpfile 81 82 if isinstance(location, six.string_types): 83 # location is a URL 84 if location.startswith('http') or location.startswith('ftp'): 85 tmpfile = os.path.join(tmpdir, os.path.basename(location)) 86 utils.urlretrieve(location, tmpfile) 87 return tmpfile 88 # location is a local path 89 elif os.path.exists(os.path.abspath(location)): 90 if not local_copy: 91 if os.path.isdir(location): 92 return location.rstrip('/') + '/' 93 else: 94 return location 95 tmpfile = os.path.join(tmpdir, os.path.basename(location)) 96 if os.path.isdir(location): 97 tmpfile += '/' 98 shutil.copytree(location, tmpfile, symlinks=True) 99 return tmpfile 100 shutil.copyfile(location, tmpfile) 101 return tmpfile 102 # location is just a string, dump it to a file 103 else: 104 tmpfd, tmpfile = tempfile.mkstemp(dir=tmpdir) 105 tmpfileobj = os.fdopen(tmpfd, 'w') 106 tmpfileobj.write(location) 107 tmpfileobj.close() 108 return tmpfile 109 110 111def get_tmp_dir(): 112 """Return the pathname of a directory on the host suitable 113 for temporary file storage. 114 115 The directory and its content will be deleted automatically 116 at the end of the program execution if they are still present. 117 """ 118 dir_name = tempfile.mkdtemp(prefix="autoserv-") 119 pid = os.getpid() 120 if not pid in __tmp_dirs: 121 __tmp_dirs[pid] = [] 122 __tmp_dirs[pid].append(dir_name) 123 return dir_name 124 125 126def __clean_tmp_dirs(): 127 """Erase temporary directories that were created by the get_tmp_dir() 128 function and that are still present. 129 """ 130 pid = os.getpid() 131 if pid not in __tmp_dirs: 132 return 133 for dir in __tmp_dirs[pid]: 134 try: 135 shutil.rmtree(dir) 136 except OSError as e: 137 if e.errno == 2: 138 pass 139 __tmp_dirs[pid] = [] 140atexit.register(__clean_tmp_dirs) 141subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs()) 142 143 144def unarchive(host, source_material): 145 """Uncompress and untar an archive on a host. 146 147 If the "source_material" is compresses (according to the file 148 extension) it will be uncompressed. Supported compression formats 149 are gzip and bzip2. Afterwards, if the source_material is a tar 150 archive, it will be untarred. 151 152 Args: 153 host: the host object on which the archive is located 154 source_material: the path of the archive on the host 155 156 Returns: 157 The file or directory name of the unarchived source material. 158 If the material is a tar archive, it will be extracted in the 159 directory where it is and the path returned will be the first 160 entry in the archive, assuming it is the topmost directory. 161 If the material is not an archive, nothing will be done so this 162 function is "harmless" when it is "useless". 163 """ 164 # uncompress 165 if (source_material.endswith(".gz") or 166 source_material.endswith(".gzip")): 167 host.run('gunzip "%s"' % (utils.sh_escape(source_material))) 168 source_material= ".".join(source_material.split(".")[:-1]) 169 elif source_material.endswith("bz2"): 170 host.run('bunzip2 "%s"' % (utils.sh_escape(source_material))) 171 source_material= ".".join(source_material.split(".")[:-1]) 172 173 # untar 174 if source_material.endswith(".tar"): 175 retval= host.run('tar -C "%s" -xvf "%s"' % ( 176 utils.sh_escape(os.path.dirname(source_material)), 177 utils.sh_escape(source_material),)) 178 source_material= os.path.join(os.path.dirname(source_material), 179 retval.stdout.split()[0]) 180 181 return source_material 182 183 184def get_server_dir(): 185 path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__) 186 return os.path.abspath(path) 187 188 189def find_pid(command): 190 for line in utils.system_output('ps -eo pid,cmd').rstrip().split('\n'): 191 (pid, cmd) = line.split(None, 1) 192 if re.search(command, cmd): 193 return int(pid) 194 return None 195 196 197def default_mappings(machines): 198 """ 199 Returns a simple mapping in which all machines are assigned to the 200 same key. Provides the default behavior for 201 form_ntuples_from_machines. """ 202 mappings = {} 203 failures = [] 204 205 mach = machines[0] 206 mappings['ident'] = [mach] 207 if len(machines) > 1: 208 machines = machines[1:] 209 for machine in machines: 210 mappings['ident'].append(machine) 211 212 return (mappings, failures) 213 214 215def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings): 216 """Returns a set of ntuples from machines where the machines in an 217 ntuple are in the same mapping, and a set of failures which are 218 (machine name, reason) tuples.""" 219 ntuples = [] 220 (mappings, failures) = mapping_func(machines) 221 222 # now run through the mappings and create n-tuples. 223 # throw out the odd guys out 224 for key in mappings: 225 key_machines = mappings[key] 226 total_machines = len(key_machines) 227 228 # form n-tuples 229 while len(key_machines) >= n: 230 ntuples.append(key_machines[0:n]) 231 key_machines = key_machines[n:] 232 233 for mach in key_machines: 234 failures.append((mach, "machine can not be tupled")) 235 236 return (ntuples, failures) 237 238 239def parse_machine(machine, user='root', password='', port=22): 240 """ 241 Parse the machine string user:pass@host:port and return it separately, 242 if the machine string is not complete, use the default parameters 243 when appropriate. 244 """ 245 246 if '@' in machine: 247 user, machine = machine.split('@', 1) 248 249 if ':' in user: 250 user, password = user.split(':', 1) 251 252 # Brackets are required to protect an IPv6 address whenever a 253 # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to 254 # it. Do not attempt to extract a (non-existent) port number from 255 # an unprotected/bare IPv6 address "xx::xx". 256 # In the Python >= 3.3 future, 'import ipaddress' will parse 257 # addresses; and maybe more. 258 bare_ipv6 = '[' != machine[0] and re.search(r':.*:', machine) 259 260 # Extract trailing :port number if any. 261 if not bare_ipv6 and re.search(r':\d*$', machine): 262 machine, port = machine.rsplit(':', 1) 263 port = int(port) 264 265 # Strip any IPv6 brackets (ssh does not support them). 266 # We'll add them back later for rsync, scp, etc. 267 if machine[0] == '[' and machine[-1] == ']': 268 machine = machine[1:-1] 269 270 if not machine or not user: 271 raise ValueError 272 273 return machine, user, password, port 274 275 276def get_public_key(): 277 """ 278 Return a valid string ssh public key for the user executing autoserv or 279 autotest. If there's no DSA or RSA public key, create a DSA keypair with 280 ssh-keygen and return it. 281 """ 282 283 ssh_conf_path = os.path.expanduser('~/.ssh') 284 285 dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub') 286 dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa') 287 288 rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub') 289 rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa') 290 291 has_dsa_keypair = os.path.isfile(dsa_public_key_path) and \ 292 os.path.isfile(dsa_private_key_path) 293 has_rsa_keypair = os.path.isfile(rsa_public_key_path) and \ 294 os.path.isfile(rsa_private_key_path) 295 296 if has_dsa_keypair: 297 print('DSA keypair found, using it') 298 public_key_path = dsa_public_key_path 299 300 elif has_rsa_keypair: 301 print('RSA keypair found, using it') 302 public_key_path = rsa_public_key_path 303 304 else: 305 print('Neither RSA nor DSA keypair found, creating DSA ssh key pair') 306 utils.system('ssh-keygen -t dsa -q -N "" -f %s' % dsa_private_key_path) 307 public_key_path = dsa_public_key_path 308 309 public_key = open(public_key_path, 'r') 310 public_key_str = public_key.read() 311 public_key.close() 312 313 return public_key_str 314