#!/usr/bin/env python2
# Copyright 2013 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import cPickle
import errno
import gzip
import multiprocessing
import optparse
import os
import signal
import subprocess
import sys
import tempfile
import thread
import threading
import time
import zlib

# An object that catches SIGINT sent to the Python process and notices
# if processes passed to wait() die by SIGINT (we need to look for
# both of those cases, because pressing Ctrl+C can result in either
# the main process or one of the subprocesses getting the signal).
#
# Before a SIGINT is seen, wait(p) will simply call p.wait() and
# return the result. Once a SIGINT has been seen (in the main process
# or a subprocess, including the one the current call is waiting for),
# wait(p) will call p.terminate() and raise ProcessWasInterrupted.
class SigintHandler(object):
  class ProcessWasInterrupted(Exception): pass
  # Return codes that indicate a child process was killed by SIGINT.
  sigint_returncodes = {-signal.SIGINT,  # Unix
                        -1073741510,     # Windows
                        }
  def __init__(self):
    """Install this object as the process-wide SIGINT handler."""
    self.__lock = threading.Lock()
    self.__processes = set()
    self.__got_sigint = False
    signal.signal(signal.SIGINT, self.__sigint_handler)
  def __on_sigint(self):
    """Record the interrupt and terminate every process being waited on.

    Callers must hold self.__lock.
    """
    self.__got_sigint = True
    while self.__processes:
      try:
        self.__processes.pop().terminate()
      except OSError:
        # The process is already gone; nothing left to terminate.
        pass
  def __sigint_handler(self, signal_num, frame):
    """Signal-handler entry point registered with signal.signal()."""
    with self.__lock:
      self.__on_sigint()
  def got_sigint(self):
    """Return True if a SIGINT has been observed so far."""
    with self.__lock:
      return self.__got_sigint
  def wait(self, p):
    """Wait for process p and return its exit code.

    Raises ProcessWasInterrupted if a SIGINT was seen before or during
    the wait, or if p itself exited with a SIGINT return code.
    """
    with self.__lock:
      if self.__got_sigint:
        try:
          p.terminate()
        except OSError:
          # Same tolerance as __on_sigint(): the process may have
          # exited already, in which case p.wait() returns at once.
          pass
      self.__processes.add(p)
    code = p.wait()
    with self.__lock:
      self.__processes.discard(p)
      if code in self.sigint_returncodes:
        self.__on_sigint()
      if self.__got_sigint:
        raise self.ProcessWasInterrupted
    return code
sigint_handler = SigintHandler()

# Return the width of the terminal, or None if it couldn't be
# determined (e.g. because we're not being run interactively).
def term_width(out):
  if not out.isatty():
    return None
  try:
    p = subprocess.Popen(["stty", "size"],
                         stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    # "stty size" prints "<rows> <columns>". Use distinct local names
    # so the 'out' parameter is not rebound.
    (size_output, err) = p.communicate()
    if p.returncode != 0 or err:
      return None
    return int(size_output.split()[1])
  except (IndexError, OSError, ValueError):
    return None

# Output transient and permanent lines of text. If several transient
# lines are written in sequence, the new will overwrite the old. We
# use this to ensure that lots of unimportant info (tests passing)
# won't drown out important info (tests failing).
class Outputter(object):
  def __init__(self, out_file):
    self.__out_file = out_file
    self.__previous_line_was_transient = False
    self.__width = term_width(out_file)  # Line width, or None if not a tty.
  def transient_line(self, msg):
    """Write msg so that the next line written overwrites it (tty only)."""
    if self.__width is None:
      self.__out_file.write(msg + "\n")
    else:
      # Return to column 0, then truncate/pad to exactly the terminal
      # width so leftovers of a longer previous line are erased.
      self.__out_file.write("\r" + msg[:self.__width].ljust(self.__width))
      self.__previous_line_was_transient = True
  def flush_transient_output(self):
    """Terminate a pending transient line so it will not be overwritten."""
    if self.__previous_line_was_transient:
      self.__out_file.write("\n")
      self.__previous_line_was_transient = False
  def permanent_line(self, msg):
    """Write msg on a line of its own that is never overwritten."""
    self.flush_transient_output()
    self.__out_file.write(msg + "\n")

# Serializes all logger output across worker threads.
stdout_lock = threading.Lock()

# Logger that shows a running "[done/total]" status line and prints the
# captured output of failing tests only.
class FilterFormat:
  if sys.stdout.isatty():
    # stdout needs to be unbuffered since the output is interactive.
    sys.stdout = os.fdopen(sys.stdout.fileno(), 'w', 0)

  out = Outputter(sys.stdout)
  total_tests = 0
  finished_tests = 0

  tests = {}     # job id -> (binary, test name)
  outputs = {}   # job id -> name of the file holding the job's output
  failures = []  # (binary, test name) of every job with a non-zero exit

  def print_test_status(self, last_finished_test, time_ms):
    """Show progress counters and the most recently finished test."""
    self.out.transient_line("[%d/%d] %s (%d ms)"
                            % (self.finished_tests, self.total_tests,
                               last_finished_test, time_ms))

  def handle_meta(self, job_id, args):
    """Process one "<COMMAND> <argument>" message for the given job.

    Commands are TEST (job started), EXIT (job finished) and TESTCNT
    (total number of jobs); any other command is ignored.
    """
    (command, arg) = args.split(' ', 1)
    if command == "TEST":
      (binary, test) = arg.split(' ', 1)
      self.tests[job_id] = (binary, test.strip())
    elif command == "EXIT":
      (exit_code, time_ms) = [int(x) for x in arg.split(' ', 1)]
      self.finished_tests += 1
      (binary, test) = self.tests[job_id]
      self.print_test_status(test, time_ms)
      if exit_code != 0:
        self.failures.append(self.tests[job_id])
        # Replay the captured output of the failing test.
        with open(self.outputs[job_id]) as f:
          for line in f:
            self.out.permanent_line(line.rstrip())
        self.out.permanent_line(
            "[%d/%d] %s returned/aborted with exit code %d (%d ms)"
            % (self.finished_tests, self.total_tests, test, exit_code,
               time_ms))
    elif command == "TESTCNT":
      # The count is the second space-separated field of the argument.
      self.total_tests = int(arg.split(' ', 1)[1])
      self.out.transient_line("[0/%d] Running tests..." % self.total_tests)

  def logfile(self, job_id, name):
    """Remember which file holds the output of the given job."""
    self.outputs[job_id] = name

  def log(self, line):
    """Handle one "<job id>: <COMMAND> <argument>" line."""
    # 'with' guarantees the lock is released even if parsing fails;
    # otherwise one bad line would block every other worker forever.
    with stdout_lock:
      (prefix, output) = line.split(' ', 1)

      assert prefix[-1] == ':'
      self.handle_meta(int(prefix[:-1]), output)

  def end(self):
    """Print the list of failed tests, if any, and finish the output."""
    if self.failures:
      self.out.permanent_line("FAILED TESTS (%d/%d):"
                              % (len(self.failures), self.total_tests))
      for (binary, test) in self.failures:
        self.out.permanent_line(" " + binary + ": " + test)
    self.out.flush_transient_output()

# Logger that echoes every message and every line of test output.
class RawFormat:
  def log(self, line):
    """Write one line to stdout and flush it."""
    with stdout_lock:
      sys.stdout.write(line + "\n")
      sys.stdout.flush()
  def logfile(self, job_id, name):
    """Echo the contents of the file called name, prefixed by the job id."""
    # Read the file that was passed in; this class keeps no per-job
    # table of output files.
    with open(name) as f:
      for line in f:
        self.log(str(job_id) + '> ' + line.rstrip())
  def end(self):
    pass

# Record of test runtimes. Has built-in locking.
class TestTimes(object):
  def __init__(self, save_file):
    "Create new object seeded with saved test times from the given file."
    self.__times = {}  # (test binary, test name) -> runtime in ms

    # Protects calls to record_test_time(); other calls are not
    # expected to be made concurrently.
    self.__lock = threading.Lock()

    # NOTE(review): unpickling executes whatever the file describes, so
    # save_file must only ever be a file written by this tool.
    try:
      with gzip.GzipFile(save_file, "rb") as f:
        times = cPickle.load(f)
    except (AttributeError, EOFError, ImportError, IndexError, IOError,
            ValueError, cPickle.UnpicklingError, zlib.error):
      # File doesn't exist, isn't readable, is malformed---whatever.
      # Just ignore it.
      return

    # Discard saved times if the format isn't right.
    if type(times) is not dict:
      return
    for (key, runtime) in times.items():
      # Check the key's shape before looking inside it, so that a
      # malformed file is ignored rather than raising while unpacking.
      if (type(key) is not tuple or len(key) != 2
          or type(key[0]) is not str or type(key[1]) is not str
          or type(runtime) not in {int, long, type(None)}):
        return

    self.__times = times

  def get_test_time(self, binary, testname):
    """Return the last duration for the given test as an integer number of
    milliseconds, or None if the test failed or if there's no record for it."""
    return self.__times.get((binary, testname), None)

  def record_test_time(self, binary, testname, runtime_ms):
    """Record that the given test ran in the specified number of
    milliseconds. If the test failed, runtime_ms should be None."""
    with self.__lock:
      self.__times[(binary, testname)] = runtime_ms

  def write_to_file(self, save_file):
    "Write all the times to file."
    try:
      with open(save_file, "wb") as f:
        with gzip.GzipFile("", "wb", 9, f) as gzf:
          cPickle.dump(self.__times, gzf, cPickle.HIGHEST_PROTOCOL)
    except IOError:
      pass  # ignore errors---saving the times isn't that important

# Remove additional arguments (anything after --).
additional_args = []

if '--' in sys.argv:
  separator_index = sys.argv.index('--')  # first occurrence
  additional_args = sys.argv[separator_index + 1:]
  sys.argv = sys.argv[:separator_index]

parser = optparse.OptionParser(
    usage = 'usage: %prog [options] binary [binary ...] -- [additional args]')

parser.add_option('-d', '--output_dir', type='string',
                  default=os.path.join(tempfile.gettempdir(), "gtest-parallel"),
                  help='output directory for test logs')
parser.add_option('-r', '--repeat', type='int', default=1,
                  help='repeat tests')
parser.add_option('-w', '--workers', type='int',
                  default=multiprocessing.cpu_count(),
                  help='number of workers to spawn')
parser.add_option('--gtest_color', type='string', default='yes',
                  help='color output')
parser.add_option('--gtest_filter', type='string', default='',
                  help='test filter')
parser.add_option('--gtest_also_run_disabled_tests', action='store_true',
                  default=False, help='run disabled tests too')
parser.add_option('--format', type='string', default='filter',
                  help='output format (raw,filter)')
parser.add_option('--print_test_times', action='store_true', default=False,
                  help='When done, list the run time of each test')

(options, binaries) = parser.parse_args()

if binaries == []:
  parser.print_usage()
  sys.exit(1)

logger = RawFormat()
if options.format == 'raw':
  pass
elif options.format == 'filter':
  logger = FilterFormat()
else:
  sys.exit("Unknown output format: " + options.format)

# Find tests.
save_file = os.path.join(os.path.expanduser("~"), ".gtest-parallel-times")
times = TestTimes(save_file)
tests = []
for test_binary in binaries:
  command = [test_binary]
  if options.gtest_also_run_disabled_tests:
    command += ['--gtest_also_run_disabled_tests']

  list_command = list(command)
  if options.gtest_filter != '':
    list_command += ['--gtest_filter=' + options.gtest_filter]

  try:
    test_list = subprocess.Popen(list_command + ['--gtest_list_tests'],
                                 stdout=subprocess.PIPE).communicate()[0]
  except OSError as e:
    sys.exit("%s: %s" % (test_binary, str(e)))

  command += additional_args

  test_group = ''
  for line in test_list.split('\n'):
    if not line.strip():
      continue
    # Lines that do not start with a space name a test group; the
    # lines that do are the tests belonging to the current group.
    if line[0] != " ":
      test_group = line.strip()
      continue
    # Remove comments for parameterized tests and strip whitespace.
    line = line.split('#')[0].strip()
    if not line:
      continue

    test = test_group + line
    if not options.gtest_also_run_disabled_tests and 'DISABLED_' in test:
      continue
    tests.append((times.get_test_time(test_binary, test),
                  test_binary, test, command))

# Sort tests by falling runtime (with None, which is what we get for
# new and failing tests, being considered larger than any real
# runtime).
tests.sort(reverse=True, key=lambda x: ((1 if x[0] is None else 0), x))

# Repeat tests (-r flag).
tests *= options.repeat
test_lock = threading.Lock()  # protects job_id
job_id = 0                    # index into tests of the next job to hand out
logger.log(str(-1) + ': TESTCNT ' + ' ' + str(len(tests)))

exit_code = 0

# Create directory for test log output.
try:
  os.makedirs(options.output_dir)
except OSError as e:
  # Ignore errors if this directory already exists.
  if e.errno != errno.EEXIST or not os.path.isdir(options.output_dir):
    raise  # bare raise keeps the original traceback
# Remove files from old test runs.
for logfile in os.listdir(options.output_dir):
  os.remove(os.path.join(options.output_dir, logfile))

# Run the specified job. Return the elapsed time in milliseconds if
# the job succeeds, or None if the job fails. (This ensures that
# failing tests will run first the next time.)
#
# The job is a single (command, job_id, test) tuple.
def run_job(job):
  global exit_code
  (command, job_id, test) = job
  begin = time.time()

  with tempfile.NamedTemporaryFile(dir=options.output_dir, delete=False) as log:
    sub = subprocess.Popen(command + ['--gtest_filter=' + test] +
                             ['--gtest_color=' + options.gtest_color],
                           stdout=log.file,
                           stderr=log.file)
    try:
      code = sigint_handler.wait(sub)
    except sigint_handler.ProcessWasInterrupted:
      # Interrupted: end this worker thread without reporting a result.
      thread.exit()
    runtime_ms = int(1000 * (time.time() - begin))
    logger.logfile(job_id, log.name)

  logger.log("%s: EXIT %s %d" % (job_id, code, runtime_ms))
  if code == 0:
    return runtime_ms
  exit_code = code
  return None

# Worker thread body: keep claiming the next unstarted job until none
# are left.
def worker():
  global job_id
  while True:
    job = None
    with test_lock:
      if job_id < len(tests):
        (_, test_binary, test, command) = tests[job_id]
        logger.log(str(job_id) + ': TEST ' + test_binary + ' ' + test)
        job = (command, job_id, test)
      job_id += 1
    if job is None:
      return
    times.record_test_time(test_binary, test, run_job(job))

# Start func on a daemon thread and return the thread object.
def start_daemon(func):
  t = threading.Thread(target=func)
  t.daemon = True
  t.start()
  return t

workers = [start_daemon(worker) for i in range(options.workers)]

for t in workers:
  t.join()
logger.end()
times.write_to_file(save_file)
if options.print_test_times:
  ts = sorted((times.get_test_time(test_binary, test), test_binary, test)
              for (_, test_binary, test, _) in tests
              if times.get_test_time(test_binary, test) is not None)
  for (time_ms, test_binary, test) in ts:
    # Parenthesized single argument: prints the same text whether
    # 'print' is a statement (Python 2) or a function.
    print("%8s %s" % ("%dms" % time_ms, test))
sys.exit(-signal.SIGINT if sigint_handler.got_sigint() else exit_code)