# Copyright 2008 Google Inc, Martin J. Bligh <mbligh@google.com>,
#                Benjamin Poirier, Ryan Stutsman
# Released under the GPL v2
"""
Miscellaneous small functions.

DO NOT import this file directly - it is mixed in by server/utils.py,
import that instead
"""

import atexit, errno, os, re, shutil, textwrap, sys, tempfile, types

from autotest_lib.client.common_lib import barrier, utils
from autotest_lib.server import subcommand


# A dictionary of pid and a list of tmpdirs for that pid
__tmp_dirs = {}

# Prefixes that make get() treat a string as a URL to download. Matching the
# full "scheme://" (rather than just "http"/"ftp") keeps local paths such as
# "httpd.conf" or "ftp_files/" from being mistaken for URLs.
_URL_PREFIXES = ('http://', 'https://', 'ftp://')


def scp_remote_escape(filename):
    """
    Escape special characters from a filename so that it can be passed
    to scp (within double quotes) as a remote file.

    Bis-quoting has to be used with scp for remote files, "bis-quoting"
    as in quoting x 2
    scp does not support a newline in the filename

    Args:
        filename: the filename string to escape.

    Returns:
        The escaped filename string. The required enclosing double
        quotes are NOT added and so should be added at some point by
        the caller.
    """
    escape_chars = r' !"$&' "'" r'()*,:;<=>?[\]^`{|}'

    # First quoting layer: backslash-escape every special character.
    new_name = "".join(
        ["\\%s" % (char,) if char in escape_chars else char
         for char in filename])

    # Second quoting layer, applied on top by utils.sh_escape().
    return utils.sh_escape(new_name)


def get(location, local_copy=False):
    """Get a file or directory to a local temporary directory.

    Args:
        location: the source of the material to get. This source may
            be one of:
            * a local file or directory
            * a URL (starting with http://, https:// or ftp://)
            * a python file-like object
            Any other string is treated as literal content and is
            written to a new temporary file.
        local_copy: only used when location is an existing local path.
            If False (the default) that path is returned as-is and
            nothing is copied; if True the file or directory is first
            copied into a temporary directory.

    Returns:
        The location of the file or directory where the requested
        content was saved. Except for an uncopied local path, this is
        inside a temporary directory on the local host (created by
        get_tmp_dir()). If the material to get was a directory, the
        location will contain a trailing '/'. None is returned when
        location is neither a file-like object nor a string.
    """
    # location is a file-like object
    if hasattr(location, "read"):
        tmpfile = os.path.join(get_tmp_dir(), "file")
        with open(tmpfile, 'w') as tmpfileobj:
            shutil.copyfileobj(location, tmpfileobj)
        return tmpfile

    if not isinstance(location, types.StringTypes):
        return None

    # location is a URL
    if location.startswith(_URL_PREFIXES):
        tmpfile = os.path.join(get_tmp_dir(), os.path.basename(location))
        utils.urlretrieve(location, tmpfile)
        return tmpfile

    # location is a local path
    if os.path.exists(os.path.abspath(location)):
        is_dir = os.path.isdir(location)
        if not local_copy:
            if is_dir:
                return location.rstrip('/') + '/'
            return location
        # normpath() drops any trailing '/', otherwise basename() would
        # be '' and a directory would be copied onto the tmpdir itself.
        basename = os.path.basename(os.path.normpath(location))
        tmpfile = os.path.join(get_tmp_dir(), basename)
        if is_dir:
            tmpfile += '/'
            shutil.copytree(location, tmpfile, symlinks=True)
            return tmpfile
        shutil.copyfile(location, tmpfile)
        return tmpfile

    # location is just a string, dump it to a file
    tmpfd, tmpfile = tempfile.mkstemp(dir=get_tmp_dir())
    with os.fdopen(tmpfd, 'w') as tmpfileobj:
        tmpfileobj.write(location)
    return tmpfile


def get_tmp_dir():
    """Return the pathname of a directory on the host suitable
    for temporary file storage.

    The directory and its content will be deleted automatically
    at the end of the program execution if they are still present.
    Directories are tracked per process id so that forked children only
    clean up the directories they created themselves.
    """
    dir_name = tempfile.mkdtemp(prefix="autoserv-")
    __tmp_dirs.setdefault(os.getpid(), []).append(dir_name)
    return dir_name


def __clean_tmp_dirs():
    """Erase temporary directories that were created by the get_tmp_dir()
    function and that are still present.

    Cleanup is best-effort: a directory that no longer exists is ignored,
    and any other removal failure is reported on stderr without stopping
    the removal of the remaining directories.
    """
    pid = os.getpid()
    if pid not in __tmp_dirs:
        return
    for tmp_dir in __tmp_dirs[pid]:
        try:
            shutil.rmtree(tmp_dir)
        except OSError as e:
            # ENOENT: already removed by someone else, nothing to do.
            if e.errno != errno.ENOENT:
                sys.stderr.write('Failed to remove temporary directory '
                                 '%s: %s\n' % (tmp_dir, e))
    __tmp_dirs[pid] = []
atexit.register(__clean_tmp_dirs)
subcommand.subcommand.register_join_hook(lambda _: __clean_tmp_dirs())


def unarchive(host, source_material):
    """Uncompress and untar an archive on a host.

    If the "source_material" is compressed (according to the file
    extension) it will be uncompressed. Supported compression formats
    are gzip (".gz", ".gzip") and bzip2 (".bz2"). Afterwards, if the
    source_material is a tar archive, it will be untarred.

    Args:
        host: the host object on which the archive is located
        source_material: the path of the archive on the host

    Returns:
        The file or directory name of the unarchived source material.
        If the material is a tar archive, it will be extracted in the
        directory where it is and the path returned will be the first
        entry in the archive, assuming it is the topmost directory.
        If the material is not an archive, nothing will be done so this
        function is "harmless" when it is "useless".
    """
    # uncompress
    if (source_material.endswith(".gz") or
            source_material.endswith(".gzip")):
        host.run('gunzip "%s"' % (utils.sh_escape(source_material)))
        # Drop the compression extension (the suffix is known to contain '.').
        source_material = source_material.rsplit(".", 1)[0]
    elif source_material.endswith(".bz2"):
        host.run('bunzip2 "%s"' % (utils.sh_escape(source_material)))
        source_material = source_material.rsplit(".", 1)[0]

    # untar
    if source_material.endswith(".tar"):
        archive_dir = os.path.dirname(source_material)
        retval = host.run('tar -C "%s" -xvf "%s"' % (
                utils.sh_escape(archive_dir),
                utils.sh_escape(source_material),))
        # With -v, tar lists one extracted entry per line; split on lines
        # (not whitespace) so entry names containing spaces stay intact.
        entries = [line for line in retval.stdout.splitlines()
                   if line.strip()]
        source_material = os.path.join(archive_dir, entries[0])

    return source_material


def get_server_dir():
    """Return the absolute path of the directory containing the
    autotest_lib.server.utils module."""
    path = os.path.dirname(sys.modules['autotest_lib.server.utils'].__file__)
    return os.path.abspath(path)


def find_pid(command):
    """Return the pid of the first process whose command line matches.

    Args:
        command: a regular expression searched (re.search) in each
            process's command line as reported by `ps -eo pid,cmd`.

    Returns:
        The pid as an int, or None if no process matches. Note that the
        ps invocation itself is part of the listing, so a pattern that
        matches "ps -eo pid,cmd" can find that process.
    """
    output = utils.system_output('ps -eo pid,cmd').rstrip()
    # The first line is the "PID CMD" column header, not a process.
    for line in output.split('\n')[1:]:
        fields = line.split(None, 1)
        if len(fields) != 2:
            continue
        pid, cmd = fields
        if re.search(command, cmd):
            return int(pid)
    return None


def default_mappings(machines):
    """
    Returns a simple mapping in which all machines are assigned to the
    same key. Provides the default behavior for
    form_ntuples_from_machines.

    Args:
        machines: a sequence of machine names (may be empty).

    Returns:
        A (mappings, failures) tuple: mappings is {'ident': [all machines]}
        and failures is always an empty list.
    """
    mappings = {'ident': list(machines)}
    failures = []
    return (mappings, failures)


def form_ntuples_from_machines(machines, n=2, mapping_func=default_mappings):
    """Returns a set of ntuples from machines where the machines in an
    ntuple are in the same mapping, and a set of failures which are
    (machine name, reason) tuples.

    Args:
        machines: the machines to group.
        n: the size of each group; must be >= 1.
        mapping_func: callable taking machines and returning a
            (mappings, failures) tuple, where mappings maps a key to the
            machines sharing that key.

    Returns:
        A (ntuples, failures) tuple. Machines left over in a mapping
        after forming as many n-sized groups as possible are added to
        failures with the reason "machine can not be tupled".

    Raises:
        ValueError: if n is smaller than 1.
    """
    if n < 1:
        raise ValueError("n must be a positive integer, got %r" % (n,))

    ntuples = []
    (mappings, failures) = mapping_func(machines)

    # now run through the mappings and create n-tuples.
    # throw out the odd guys out
    for key in mappings:
        key_machines = mappings[key]
        tupled_count = len(key_machines) - len(key_machines) % n

        # form n-tuples
        for start in range(0, tupled_count, n):
            ntuples.append(key_machines[start:start + n])

        for mach in key_machines[tupled_count:]:
            failures.append((mach, "machine can not be tupled"))

    return (ntuples, failures)


def parse_machine(machine, user='root', password='', port=22):
    """
    Parse the machine string user:pass@host:port and return it separately,
    if the machine string is not complete, use the default parameters
    when appropriate.

    Returns:
        A (machine, user, password, port) tuple; port is an int.

    Raises:
        ValueError: if the host or user part ends up empty, or if the
            port part is not a valid integer.
    """

    if '@' in machine:
        user, machine = machine.split('@', 1)

    if ':' in user:
        user, password = user.split(':', 1)

    # Brackets are required to protect an IPv6 address whenever a
    # [xx::xx]:port number (or a file [xx::xx]:/path/) is appended to
    # it. Do not attempt to extract a (non-existent) port number from
    # an unprotected/bare IPv6 address "xx::xx".
    # In the Python >= 3.3 future, 'import ipaddress' will parse
    # addresses; and maybe more.
    # startswith()/endswith() (not [0]/[-1]) so that an empty host reaches
    # the ValueError below instead of raising IndexError.
    bare_ipv6 = not machine.startswith('[') and re.search(r':.*:', machine)

    # Extract trailing :port number if any.
    if not bare_ipv6 and re.search(r':\d*$', machine):
        machine, port = machine.rsplit(':', 1)
        port = int(port)

    # Strip any IPv6 brackets (ssh does not support them).
    # We'll add them back later for rsync, scp, etc.
    if machine.startswith('[') and machine.endswith(']'):
        machine = machine[1:-1]

    if not machine or not user:
        raise ValueError

    return machine, user, password, port


def get_public_key():
    """
    Return a valid string ssh public key for the user executing autoserv or
    autotest. If there's no DSA or RSA public key, create a DSA keypair with
    ssh-keygen and return it.
    """

    ssh_conf_path = os.path.expanduser('~/.ssh')

    dsa_public_key_path = os.path.join(ssh_conf_path, 'id_dsa.pub')
    dsa_private_key_path = os.path.join(ssh_conf_path, 'id_dsa')

    rsa_public_key_path = os.path.join(ssh_conf_path, 'id_rsa.pub')
    rsa_private_key_path = os.path.join(ssh_conf_path, 'id_rsa')

    has_dsa_keypair = (os.path.isfile(dsa_public_key_path) and
                       os.path.isfile(dsa_private_key_path))
    has_rsa_keypair = (os.path.isfile(rsa_public_key_path) and
                       os.path.isfile(rsa_private_key_path))

    if has_dsa_keypair:
        print('DSA keypair found, using it')
        public_key_path = dsa_public_key_path

    elif has_rsa_keypair:
        print('RSA keypair found, using it')
        public_key_path = rsa_public_key_path

    else:
        print('Neither RSA nor DSA keypair found, creating DSA ssh key pair')
        # NOTE(review): recent OpenSSH releases disable or drop DSA keys;
        # consider generating RSA/ed25519 instead - confirm with callers.
        # Quote the path so a home directory containing spaces works.
        utils.system('ssh-keygen -t dsa -q -N "" -f "%s"' %
                     utils.sh_escape(dsa_private_key_path))
        public_key_path = dsa_public_key_path

    with open(public_key_path, 'r') as public_key:
        return public_key.read()