# Copyright (c) 2015, Google Inc.
#
# Permission to use, copy, modify, and/or distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
# SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
# OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
# CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""Extracts archives."""

import hashlib
import optparse
import os
import os.path
import tarfile
import shutil
import sys
import zipfile


def CheckedJoin(output, path):
  """
  CheckedJoin returns os.path.join(output, path). It does sanity checks to
  ensure the resulting path is under output, but shouldn't be used on untrusted
  input.

  Raises ValueError if path (after normalization) is absolute, carries a drive
  letter, names the output directory itself, or escapes it through a leading
  '..' component. Names that merely begin with a dot (e.g. '.gitignore') are
  accepted.
  """
  path = os.path.normpath(path)
  # normpath collapses interior 'a/..' pairs, so the only way a relative path
  # can still climb out of output is a leading '..' component. Checking the
  # first component, not the first character, keeps legitimate dotfiles.
  if (os.path.isabs(path) or os.path.splitdrive(path)[0] or
      path in (os.curdir, os.pardir) or
      path.startswith(os.pardir + os.sep)):
    raise ValueError(path)
  return os.path.join(output, path)


class FileEntry(object):
  """A regular file read from an archive.

  Attributes:
    path: the entry's name as stored in the archive.
    mode: permission bits to apply after extraction, or None to skip chmod.
    fileobj: readable file object yielding the entry's contents.
  """

  def __init__(self, path, mode, fileobj):
    self.path = path
    self.mode = mode
    self.fileobj = fileobj


class SymlinkEntry(object):
  """A symbolic link read from an archive.

  Attributes:
    path: the entry's name as stored in the archive.
    mode: permission bits to apply after extraction, or None to skip chmod.
    target: the link target, written verbatim by os.symlink.
  """

  def __init__(self, path, mode, target):
    self.path = path
    self.mode = mode
    self.target = target


def IterateZip(path):
  """
  IterateZip opens the zip file at path and returns a generator of entry objects
  for each file in it. Directory entries are skipped, and zip entries are
  yielded with mode None.
  """
  with zipfile.ZipFile(path, 'r') as zip_file:
    for info in zip_file.infolist():
      # Directory entries end in '/'; parent directories are created on demand
      # when the files inside them are written.
      if info.filename.endswith('/'):
        continue
      yield FileEntry(info.filename, None, zip_file.open(info))


def IterateTar(path, compression):
  """
  IterateTar opens the tar.gz or tar.bz2 file at path and returns a generator of
  entry objects for each file in it. Directories are skipped, regular files keep
  their tar mode, symlinks are yielded with mode None, and any other member type
  raises ValueError.
  """
  with tarfile.open(path, 'r:' + compression) as tar_file:
    for info in tar_file:
      if info.isdir():
        pass
      elif info.issym():
        yield SymlinkEntry(info.name, None, info.linkname)
      elif info.isfile():
        yield FileEntry(info.name, info.mode,
                        tar_file.extractfile(info))
      else:
        raise ValueError('Unknown entry type "%s"' % (info.name, ))


def _Sha256Digest(path, chunk_size=1024 * 1024):
  """Returns the hex SHA-256 digest of the file at path, read in chunks."""
  sha256 = hashlib.sha256()
  # Binary mode is required: text mode translates line endings on Windows (and
  # decodes to str on Python 3), so the digest would not match the file bytes.
  with open(path, 'rb') as f:
    while True:
      chunk = f.read(chunk_size)
      if not chunk:
        break
      sha256.update(chunk)
  return sha256.hexdigest()


def _WriteEntry(entry, fixed_path):
  """Writes one FileEntry or SymlinkEntry to fixed_path, then applies its mode.

  Raises TypeError for any other entry type.
  """
  parent = os.path.dirname(fixed_path)
  if not os.path.isdir(parent):
    os.makedirs(parent)
  if isinstance(entry, FileEntry):
    try:
      with open(fixed_path, 'wb') as out:
        shutil.copyfileobj(entry.fileobj, out)
    finally:
      # Release the per-entry handle now instead of waiting for garbage
      # collection; for named zip files each entry may hold its own handle.
      entry.fileobj.close()
  elif isinstance(entry, SymlinkEntry):
    os.symlink(entry.target, fixed_path)
  else:
    raise TypeError('unknown entry type')

  # Fix up permissions if needbe.
  # TODO(davidben): To be extra tidy, this should only track the execute bit
  # as in git.
  if entry.mode is not None:
    os.chmod(fixed_path, entry.mode)


def main(args):
  """Extracts ARCHIVE into OUTPUT unless OUTPUT already holds that archive.

  A digest stamp written inside OUTPUT records which archive was last
  extracted there; if it matches, extraction is skipped. Otherwise OUTPUT is
  removed and recreated from the archive. Unless --no-prefix is given, every
  entry must share one top-level directory, which is stripped.

  Returns 0 on success (including a missing archive or an up-to-date stamp)
  and 1 on bad usage. Raises ValueError for unsupported or inconsistent
  archives.
  """
  parser = optparse.OptionParser(usage='Usage: %prog ARCHIVE OUTPUT')
  parser.add_option('--no-prefix',
                    dest='no_prefix',
                    action='store_true',
                    help='Do not remove a prefix from paths in the archive.')
  options, args = parser.parse_args(args)

  if len(args) != 2:
    parser.print_help()
    return 1

  archive, output = args

  if not os.path.exists(archive):
    # Skip archives that weren't downloaded.
    return 0

  digest = _Sha256Digest(archive)

  stamp_path = os.path.join(output, ".dawn_archive_digest")
  if os.path.exists(stamp_path):
    with open(stamp_path) as f:
      if f.read().strip() == digest:
        print("Already up-to-date.")
        return 0

  if archive.endswith('.zip'):
    entries = IterateZip(archive)
  elif archive.endswith('.tar.gz'):
    entries = IterateTar(archive, 'gz')
  elif archive.endswith('.tar.bz2'):
    entries = IterateTar(archive, 'bz2')
  else:
    raise ValueError(archive)

  try:
    if os.path.exists(output):
      print("Removing %s" % (output, ))
      shutil.rmtree(output)

    print("Extracting %s to %s" % (archive, output))
    prefix = None
    num_extracted = 0
    for entry in entries:
      # Even on Windows, zip files must always use forward slashes.
      if '\\' in entry.path or entry.path.startswith('/'):
        raise ValueError(entry.path)

      if not options.no_prefix:
        new_prefix, sep, rest = entry.path.partition('/')
        if not sep:
          raise ValueError('Entry "%s" is not under a top-level directory; '
                           'use --no-prefix' % (entry.path, ))

        # Ensure the archive is consistent.
        if prefix is None:
          prefix = new_prefix
        if prefix != new_prefix:
          raise ValueError((prefix, new_prefix))
      else:
        rest = entry.path

      # Extract the entry into the output directory.
      _WriteEntry(entry, CheckedJoin(output, rest))

      # Print every 100 files, so bots do not time out on large archives.
      num_extracted += 1
      if num_extracted % 100 == 0:
        print("Extracted %d files..." % (num_extracted, ))
  finally:
    # Closing the generator also closes the underlying archive.
    entries.close()

  # An archive with no files never recreates output after the rmtree above,
  # so make sure it exists before writing the stamp into it.
  if not os.path.isdir(output):
    os.makedirs(output)
  with open(stamp_path, 'w') as f:
    f.write(digest)

  print("Done. Extracted %d files." % (num_extracted, ))
  return 0


if __name__ == '__main__':
  sys.exit(main(sys.argv[1:]))