# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Coordinator to help multiple threads stop when requested."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import sys
import threading
import time

import six

from tensorflow.python.framework import errors
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


@tf_export("train.Coordinator")
class Coordinator(object):
  """A coordinator for threads.

  This class implements a simple mechanism to coordinate the termination of a
  set of threads.

  #### Usage:

  ```python
  # Create a coordinator.
  coord = Coordinator()
  # Start a number of threads, passing the coordinator to each of them.
  ...start thread 1...(coord, ...)
  ...start thread N...(coord, ...)
  # Wait for all the threads to terminate.
  coord.join(threads)
  ```

  Any of the threads can call `coord.request_stop()` to ask for all the threads
  to stop.  To cooperate with the requests, each thread must check for
  `coord.should_stop()` on a regular basis.  `coord.should_stop()` returns
  `True` as soon as `coord.request_stop()` has been called.

  A typical thread running with a coordinator will do something like:

  ```python
  while not coord.should_stop():
    ...do some work...
  ```

  #### Exception handling:

  A thread can report an exception to the coordinator as part of the
  `request_stop()` call.  The exception will be re-raised from the
  `coord.join()` call.

  Thread code:

  ```python
  try:
    while not coord.should_stop():
      ...do some work...
  except Exception as e:
    coord.request_stop(e)
  ```

  Main code:

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate.
    coord.join(threads)
  except Exception as e:
    ...exception that was passed to coord.request_stop()
  ```

  To simplify the thread implementation, the Coordinator provides a
  context handler `stop_on_exception()` that automatically requests a stop if
  an exception is raised.  Using the context handler the thread code above
  can be written as:

  ```python
  with coord.stop_on_exception():
    while not coord.should_stop():
      ...do some work...
  ```

  #### Grace period for stopping:

  After a thread has called `coord.request_stop()` the other threads have a
  fixed time to stop, this is called the 'stop grace period' and defaults to 2
  minutes.  If any of the threads is still alive after the grace period expires
  `coord.join()` raises a RuntimeError reporting the laggards.

  ```python
  try:
    ...
    coord = Coordinator()
    # Start a number of threads, passing the coordinator to each of them.
    ...start thread 1...(coord, ...)
    ...start thread N...(coord, ...)
    # Wait for all the threads to terminate, give them 10s grace period
    coord.join(threads, stop_grace_period_secs=10)
  except RuntimeError:
    ...one of the threads took more than 10s to stop after request_stop()
    ...was called.
  except Exception:
    ...exception that was passed to coord.request_stop()
  ```
  """

  def __init__(self, clean_stop_exception_types=None):
    """Create a new Coordinator.

    Args:
      clean_stop_exception_types: Optional tuple of Exception types that should
        cause a clean stop of the coordinator. If an exception of one of these
        types is reported to `request_stop(ex)` the coordinator will behave as
        if `request_stop(None)` was called.  Defaults to
        `(tf.errors.OutOfRangeError,)` which is used by input queues to signal
        the end of input. When feeding training data from a Python iterator it
        is common to add `StopIteration` to this list.
    """
    if clean_stop_exception_types is None:
      clean_stop_exception_types = (errors.OutOfRangeError,)
    # Stored as a tuple so it can be handed directly to isinstance().
    self._clean_stop_exception_types = tuple(clean_stop_exception_types)
    # Protects all attributes.
    self._lock = threading.Lock()
    # Event set when threads must stop.
    self._stop_event = threading.Event()
    # Python exc_info to report.
    # If not None, it should hold the returned value of sys.exc_info(), which is
    # a tuple containing exception (type, value, traceback).
    self._exc_info_to_raise = None
    # True if we have called join() already.
    self._joined = False
    # Set of threads registered for joining when join() is called.  These
    # threads will be joined in addition to the threads passed to the join()
    # call.  It's ok if threads are both registered and passed to the join()
    # call.
    self._registered_threads = set()

  def _filter_exception(self, ex):
    """Check if the exception indicated in 'ex' should be ignored.

    This method examines `ex` to check if it is an exception that should be
    reported to the users.  If yes, it returns `ex` as is, otherwise it returns
    None.

    The code returns None for exception types listed in
    `_clean_stop_exception_types`.

    Args:
      ex: None, an `Exception`, or a Python `exc_info` tuple as returned by
        `sys.exc_info()`.

    Returns:
      ex or None.
    """
    # For an exc_info tuple the exception instance is the second element.
    if isinstance(ex, tuple):
      ex2 = ex[1]
    else:
      ex2 = ex
    if isinstance(ex2, self._clean_stop_exception_types):
      # Ignore the exception.
      ex = None
    return ex

  def request_stop(self, ex=None):
    """Request that the threads stop.

    After this is called, calls to `should_stop()` will return `True`.

    Note: If an exception is being passed in, it must be in the context of
    handling the exception (i.e. `try: ... except Exception as ex: ...`) and not
    a newly created one.

    Args:
      ex: Optional `Exception`, or Python `exc_info` tuple as returned by
        `sys.exc_info()`.  If this is the first call to `request_stop()` the
        corresponding exception is recorded and re-raised from `join()`.

    Raises:
      Exception: If `join()` has already completed and `ex` is not filtered
        out as a clean-stop exception, `ex` is raised directly from this call
        because it can no longer be reported through `join()`.
    """
    with self._lock:
      ex = self._filter_exception(ex)
      # If we have already joined the coordinator the exception will not have a
      # chance to be reported, so just raise it normally.  This can happen if
      # you continue to use a session after having stopped and joined the
      # coordinator threads.
      if self._joined:
        if isinstance(ex, tuple):
          six.reraise(*ex)
        elif ex is not None:
          current_exc_info = sys.exc_info()
          if current_exc_info[1] is ex:
            # Called from the handler that caught `ex`: re-raise it with its
            # original traceback.
            six.reraise(*current_exc_info)
          # Not called from the handler of `ex`: sys.exc_info() is empty or
          # refers to some other exception, so re-raising it would either
          # fail inside six.reraise() or surface the wrong error.  Raise `ex`
          # itself instead.
          raise ex
      if not self._stop_event.is_set():
        # Only the first reported exception is recorded.
        if ex and self._exc_info_to_raise is None:
          if isinstance(ex, tuple):
            logging.info("Error reported to Coordinator: %s",
                         compat.as_str_any(ex[1]),
                         exc_info=ex)
            self._exc_info_to_raise = ex
          else:
            logging.info("Error reported to Coordinator: %s, %s",
                         type(ex),
                         compat.as_str_any(ex))
            self._exc_info_to_raise = sys.exc_info()
          # self._exc_info_to_raise should contain a tuple containing exception
          # (type, value, traceback)
          if (len(self._exc_info_to_raise) != 3 or
              not self._exc_info_to_raise[0] or
              not self._exc_info_to_raise[1]):
            # Raise, catch and record the exception here so that error happens
            # where expected.
            try:
              # The tuple is wrapped in a 1-tuple so that '%' formats it as a
              # single value.  Passing the bare tuple would spread its items
              # over the format arguments and raise TypeError instead of the
              # intended ValueError.
              raise ValueError(
                  "ex must be a tuple or sys.exc_info must return the current "
                  "exception: %s"
                  % (self._exc_info_to_raise,))
            except ValueError:
              # Record this error so it kills the coordinator properly.
              # NOTE(touts): As above, this is bogus if request_stop() is not
              # called from the exception handler that raised ex.
              self._exc_info_to_raise = sys.exc_info()

        self._stop_event.set()

  def clear_stop(self):
    """Clears the stop flag.

    After this is called, calls to `should_stop()` will return `False`.
    Any recorded exception and the "joined" state are reset as well.
    """
    with self._lock:
      self._joined = False
      self._exc_info_to_raise = None
      if self._stop_event.is_set():
        self._stop_event.clear()

  def should_stop(self):
    """Check if stop was requested.

    Returns:
      True if a stop was requested.
    """
    return self._stop_event.is_set()

  @contextlib.contextmanager
  def stop_on_exception(self):
    """Context manager to request stop when an Exception is raised.

    Code that uses a coordinator must catch exceptions and pass
    them to the `request_stop()` method to stop the other threads
    managed by the coordinator.

    This context handler simplifies the exception handling.
    Use it as follows:

    ```python
    with coord.stop_on_exception():
      # Any exception raised in the body of the with
      # clause is reported to the coordinator before terminating
      # the execution of the body.
      ...body...
    ```

    This is completely equivalent to the slightly longer code:

    ```python
    try:
      ...body...
    except:
      coord.request_stop(sys.exc_info())
    ```

    Yields:
      nothing.
    """
    try:
      yield
    # The bare except is deliberate: every exception raised by the body is
    # handed to the coordinator.
    except:  # pylint: disable=bare-except
      self.request_stop(ex=sys.exc_info())

  def wait_for_stop(self, timeout=None):
    """Wait till the Coordinator is told to stop.

    Args:
      timeout: Float.  Sleep for up to that many seconds waiting for
        should_stop() to become True.

    Returns:
      True if the Coordinator is told to stop, False if the timeout expired.
    """
    return self._stop_event.wait(timeout)

  def register_thread(self, thread):
    """Register a thread to join.

    Args:
      thread: A Python thread to join.
    """
    with self._lock:
      self._registered_threads.add(thread)

  def join(self, threads=None, stop_grace_period_secs=120,
           ignore_live_threads=False):
    """Wait for threads to terminate.

    This call blocks until a set of threads have terminated.  The set of threads
    is the union of the threads passed in the `threads` argument and the list
    of threads that registered with the coordinator by calling
    `Coordinator.register_thread()`.

    After the threads stop, if an `exc_info` was passed to `request_stop`, that
    exception is re-raised.

    Grace period handling: When `request_stop()` is called, threads are given
    'stop_grace_period_secs' seconds to terminate.  If any of them is still
    alive after that period expires, a `RuntimeError` is raised.  Note that if
    an `exc_info` was passed to `request_stop()` then it is raised instead of
    that `RuntimeError`.

    Args:
      threads: List of `threading.Threads`. The started threads to join in
        addition to the registered threads.
      stop_grace_period_secs: Number of seconds given to threads to stop after
        `request_stop()` has been called.
      ignore_live_threads: If `False`, raises an error if any of the threads are
        still alive after `stop_grace_period_secs`.

    Raises:
      RuntimeError: If any thread is still alive after `request_stop()`
        is called and the grace period expires.
    """
    # Threads registered after this call will not be joined.
    with self._lock:
      if threads is None:
        threads = self._registered_threads
      else:
        threads = self._registered_threads.union(set(threads))
      # Copy the set into a list to avoid race conditions where a new thread
      # is added while we are waiting.
      threads = list(threads)

    # Wait for all threads to stop or for request_stop() to be called.
    while any(t.is_alive() for t in threads) and not self.wait_for_stop(1.0):
      pass

    # If any thread is still alive, wait for the grace period to expire.
    # By the time this check is executed, threads may still be shutting down,
    # so we add a sleep of increasing duration to give them a chance to shut
    # down without losing too many cycles.
    # The sleep duration is limited to the remaining grace duration.
    stop_wait_secs = 0.001
    while any(t.is_alive() for t in threads) and stop_grace_period_secs >= 0.0:
      time.sleep(stop_wait_secs)
      stop_grace_period_secs -= stop_wait_secs
      stop_wait_secs = 2 * stop_wait_secs
      # Keep the waiting period within sane bounds.
      # The minimum value is to avoid decreasing stop_wait_secs to a value
      # that could cause stop_grace_period_secs to remain unchanged.
      stop_wait_secs = max(min(stop_wait_secs, stop_grace_period_secs), 0.001)

    # List the threads still alive after the grace period.
    stragglers = [t.name for t in threads if t.is_alive()]

    # Terminate with an exception if appropriate.
    with self._lock:
      self._joined = True
      self._registered_threads = set()
      # A recorded exception takes precedence over the straggler report.
      if self._exc_info_to_raise:
        six.reraise(*self._exc_info_to_raise)
      elif stragglers:
        if ignore_live_threads:
          logging.info("Coordinator stopped with threads still running: %s",
                       " ".join(stragglers))
        else:
          raise RuntimeError(
              "Coordinator stopped with threads still running: %s" %
              " ".join(stragglers))

  @property
  def joined(self):
    """True if `join()` has completed since the last `clear_stop()`."""
    return self._joined

  def raise_requested_exception(self):
    """If an exception has been passed to `request_stop`, this raises it."""
    with self._lock:
      if self._exc_info_to_raise:
        six.reraise(*self._exc_info_to_raise)


# Threads for the standard services.
@tf_export(v1=["train.LooperThread"])
class LooperThread(threading.Thread):
  """A thread that runs code repeatedly, optionally on a timer.

  This thread class is intended to be used with a `Coordinator`.  It repeatedly
  runs code specified either as `target` and `args` or by the `run_loop()`
  method.

  Before each run the thread checks if the coordinator has requested stop.  In
  that case the looper thread terminates immediately.

  If the code being run raises an exception, that exception is reported to the
  coordinator and the thread terminates.  The coordinator will then request all
  the other threads it coordinates to stop.

  You typically pass looper threads to the supervisor `Join()` method.
  """

  def __init__(self, coord, timer_interval_secs, target=None, args=None,
               kwargs=None):
    """Create a LooperThread.

    The thread is created as a daemon thread and registers itself with
    `coord` so that it is joined by `coord.join()`.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Time boundaries at which to call Run(), or None
        if it should be called back to back.
      target: Optional callable object that will be executed in the thread.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    if not isinstance(coord, Coordinator):
      raise ValueError("'coord' argument must be a Coordinator: %s" % coord)
    super(LooperThread, self).__init__()
    self.daemon = True
    self._coord = coord
    self._timer_interval_secs = timer_interval_secs
    self._target = target
    if self._target:
      self._args = args or ()
      self._kwargs = kwargs or {}
    elif args or kwargs:
      # Arguments without a callable to receive them are a caller error.
      raise ValueError("'args' and 'kwargs' argument require that you also "
                       "pass 'target'")
    self._coord.register_thread(self)

  @staticmethod
  def loop(coord, timer_interval_secs, target, args=None, kwargs=None):
    """Start a LooperThread that calls a function periodically.

    If `timer_interval_secs` is None the thread calls `target(args)`
    repeatedly.  Otherwise `target(args)` is called every `timer_interval_secs`
    seconds.  The thread terminates when a stop of the coordinator is
    requested.

    Args:
      coord: A Coordinator.
      timer_interval_secs: Number. Time boundaries at which to call `target`.
      target: A callable object.
      args: Optional arguments to pass to `target` when calling it.
      kwargs: Optional keyword arguments to pass to `target` when calling it.

    Returns:
      The started thread.
    """
    looper = LooperThread(coord, timer_interval_secs, target=target, args=args,
                          kwargs=kwargs)
    looper.start()
    return looper

  def run(self):
    """Run `run_loop()` repeatedly until the coordinator requests a stop.

    Calls `start_loop()` once, then `run_loop()` either back to back or at
    `timer_interval_secs` boundaries, then `stop_loop()` once.  Any exception
    raised along the way is reported to the coordinator.
    """
    with self._coord.stop_on_exception():
      self.start_loop()
      if self._timer_interval_secs is None:
        # Call back-to-back.
        while not self._coord.should_stop():
          self.run_loop()
      else:
        # Next time at which to call run_loop(), starts as 'now'.
        next_timer_time = time.time()
        while not self._coord.wait_for_stop(next_timer_time - time.time()):
          next_timer_time += self._timer_interval_secs
          self.run_loop()
      self.stop_loop()

  def start_loop(self):
    """Called when the thread starts."""
    pass

  def stop_loop(self):
    """Called when the thread stops."""
    pass

  def run_loop(self):
    """Called at 'timer_interval_secs' boundaries."""
    if self._target:
      self._target(*self._args, **self._kwargs)