# Copyright 2013 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

"""Thread and ThreadGroup that reraise exceptions on the main thread."""
# pylint: disable=W0212

import logging
import sys
import threading
import time
import traceback

from devil.utils import watchdog_timer


# NOTE: intentionally shares its name with the Python 3 builtin; callers catch
# it as reraiser_thread.TimeoutError.
class TimeoutError(Exception):  # pylint: disable=redefined-builtin
  """Module-specific timeout exception."""
  pass


def LogThreadStack(thread, error_log_func=logging.critical):
  """Log the stack for the given thread.

  Args:
    thread: a threading.Thread instance.
    error_log_func: Logging function when logging errors.
  """
  # A thread that has already exited (or never started) has no entry in the
  # frame map. Indexing directly would raise KeyError, which, when called from
  # an exception handler, would hide the error actually being reported.
  stack = sys._current_frames().get(thread.ident)
  error_log_func('*' * 80)
  error_log_func('Stack dump for thread %r', thread.name)
  error_log_func('*' * 80)
  if stack is None:
    error_log_func('No stack available; thread %r is not running.',
                   thread.name)
  else:
    for filename, lineno, name, line in traceback.extract_stack(stack):
      error_log_func('File: "%s", line %d, in %s', filename, lineno, name)
      if line:
        error_log_func('  %s', line.strip())
  error_log_func('*' * 80)


def _Reraise(exc_info):
  """Reraise an exception captured with sys.exc_info(), keeping its traceback.

  Args:
    exc_info: the (type, value, traceback) tuple returned by sys.exc_info().
  """
  exc_type, exc_value, exc_tb = exc_info
  if sys.version_info[0] >= 3:
    raise exc_value.with_traceback(exc_tb)
  # The three-expression raise form is a SyntaxError on Python 3, so it is only
  # compiled (via a constant string) when actually running on Python 2.
  exec('raise exc_type, exc_value, exc_tb')  # pylint: disable=exec-used


class ReraiserThread(threading.Thread):
  """Thread class that can reraise exceptions."""

  def __init__(self, func, args=None, kwargs=None, name=None):
    """Initialize thread.

    Args:
      func: callable to call on a new thread.
      args: list of positional arguments for callable, defaults to empty.
      kwargs: dictionary of keyword arguments for callable, defaults to empty.
      name: thread name, defaults to the callable's __name__ when it has a
        meaningful one, otherwise Thread-N.
    """
    if not name:
      # Lambdas and callables without a __name__ (e.g. functools.partial)
      # keep threading's default Thread-N name.
      func_name = getattr(func, '__name__', None)
      if func_name and func_name != '<lambda>':
        name = func_name
    super(ReraiserThread, self).__init__(name=name)
    if not args:
      args = []
    if not kwargs:
      kwargs = {}
    self.daemon = True
    self._func = func
    self._args = args
    self._kwargs = kwargs
    self._ret = None
    self._exc_info = None
    self._thread_group = None

  def ReraiseIfException(self):
    """Reraise exception if an exception was raised in the thread."""
    if self._exc_info:
      _Reraise(self._exc_info)

  def GetReturnValue(self):
    """Reraise exception if present, otherwise get the return value."""
    self.ReraiseIfException()
    return self._ret

  # override
  def run(self):
    """Overrides Thread.run() to add support for reraising exceptions."""
    try:
      self._ret = self._func(*self._args, **self._kwargs)
    except:  # pylint: disable=W0702
      # Deliberately catch everything so it can be reraised on the joining
      # thread by ReraiseIfException().
      self._exc_info = sys.exc_info()


class ReraiserThreadGroup(object):
  """A group of ReraiserThread objects."""

  def __init__(self, threads=None):
    """Initialize thread group.

    Args:
      threads: a list of ReraiserThread objects; defaults to empty.
    """
    self._threads = []
    # Set when a thread from one group has called JoinAll on another. It is used
    # to detect when there is a TimeoutRetryThread active that links to the
    # current thread.
    self.blocked_parent_thread_group = None
    if threads:
      for thread in threads:
        self.Add(thread)

  def Add(self, thread):
    """Add a thread to the group.

    Args:
      thread: a ReraiserThread object.
    """
    assert thread._thread_group is None
    thread._thread_group = self
    self._threads.append(thread)

  def StartAll(self, will_block=False):
    """Start all threads.

    Args:
      will_block: Whether the calling thread will subsequently block on this
        thread group. Causes the active ReraiserThreadGroup (if there is one)
        to be marked as blocking on this thread group.
    """
    if will_block:
      # Multiple threads blocking on the same outer thread should not happen in
      # practice.
      assert not self.blocked_parent_thread_group
      self.blocked_parent_thread_group = CurrentThreadGroup()
    for thread in self._threads:
      thread.start()

  def _JoinAll(self, watcher=None, timeout=None):
    """Join all threads without stack dumps.

    Reraises exceptions raised by the child threads and supports breaking
    immediately on exceptions raised on the main thread.

    Args:
      watcher: Watchdog object providing the thread timeout. If none is
        provided, the thread will never be timed out.
      timeout: An optional number of seconds to wait before timing out the join
        operation. This will not time out the threads. A value of None or 0
        means wait without a join deadline.

    Raises:
      TimeoutError: if |watcher| times out while threads are still alive.
    """
    if watcher is None:
      watcher = watchdog_timer.WatchdogTimer(None)
    alive_threads = self._threads[:]
    end_time = (time.time() + timeout) if timeout else None
    try:
      while alive_threads and (end_time is None or end_time > time.time()):
        for thread in alive_threads[:]:
          if watcher.IsTimedOut():
            raise TimeoutError('Timed out waiting for %d of %d threads.' %
                               (len(alive_threads), len(self._threads)))
          # Allow the main thread to periodically check for interrupts.
          thread.join(0.1)
          if not thread.is_alive():
            alive_threads.remove(thread)
      # All threads are allowed to complete before reraising exceptions.
      for thread in self._threads:
        thread.ReraiseIfException()
    finally:
      self.blocked_parent_thread_group = None

  def IsAlive(self):
    """Check whether any of the threads are still alive.

    Returns:
      Whether any of the threads are still alive.
    """
    return any(t.is_alive() for t in self._threads)

  def JoinAll(self, watcher=None, timeout=None,
              error_log_func=logging.critical):
    """Join all threads.

    Reraises exceptions raised by the child threads and supports breaking
    immediately on exceptions raised on the main thread. Unfinished threads'
    stacks will be logged on watchdog timeout.

    Args:
      watcher: Watchdog object providing the thread timeout. If none is
        provided, the thread will never be timed out.
      timeout: An optional number of seconds to wait before timing out the join
        operation. This will not time out the threads.
      error_log_func: Logging function when logging errors.
    """
    try:
      self._JoinAll(watcher, timeout)
    except TimeoutError:
      error_log_func('Timed out. Dumping threads.')
      for thread in (t for t in self._threads if t.is_alive()):
        LogThreadStack(thread, error_log_func=error_log_func)
      raise

  def GetAllReturnValues(self, watcher=None):
    """Get all return values, joining all threads if necessary.

    Args:
      watcher: same as in |JoinAll|. Only used if threads are alive.

    Returns:
      A list of the threads' return values, in the order they were added.
    """
    if self.IsAlive():
      self.JoinAll(watcher)
    return [t.GetReturnValue() for t in self._threads]


def CurrentThreadGroup():
  """Returns the ReraiserThreadGroup that owns the running thread.

  Returns:
    The current thread group, otherwise None.
  """
  current_thread = threading.current_thread()
  if isinstance(current_thread, ReraiserThread):
    return current_thread._thread_group  # pylint: disable=no-member
  return None


def RunAsync(funcs, watcher=None):
  """Executes the given functions in parallel and returns their results.

  Args:
    funcs: List of functions to perform on their own threads.
    watcher: Watchdog object providing timeout, by default waits forever.

  Returns:
    A list of return values in the order of the given functions.
  """
  thread_group = ReraiserThreadGroup(ReraiserThread(f) for f in funcs)
  thread_group.StartAll(will_block=True)
  return thread_group.GetAllReturnValues(watcher=watcher)