1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# =================================================================== 15"""ErrorRendezvous handler for collecting errors from multiple threads.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import sys 23import threading 24import time 25 26import six 27 28from tensorflow.python.framework import errors 29from tensorflow.python.platform import tf_logging as logging 30 31_UNINTERESTING_ERRORS = (errors.CancelledError,) 32 33 34class ErrorRendezvous(object): 35 """Resolve errors from multiple threads during TPU execution. 36 37 TPU errors can occur on the infeed or outfeed threads as well as the main 38 training thread. 39 40 Depending on which thread "wins" and receives the session error first, we may 41 end up showing users a confusing and non-actionable error message (session 42 cancelled) instead of a root cause (e.g. a bad filename). 43 44 The rendezvous object provides a location to capture these errors until all 45 threads terminate. At that point we can choose the most informative error 46 to report. 47 """ 48 49 def __init__(self, num_sources): 50 # string -> (message, traceback) 51 self._errors = {} 52 self._num_sources = num_sources 53 self._session_cancel_timer = None 54 55 def record_error(self, source, exc_info, session=None): 56 """Report an exception from the given source. 57 58 If a session is passed, a timer will be registered to close it after a few 59 seconds. This is necessary to ensure the main training loop does not hang 60 if an infeed/oufeed error occurs. We sleep a few seconds to allow a more 61 interesting error from another thread to propagate. 62 63 Args: 64 source: string, source of the error 65 exc_info: Output from `sys.exc_info` (type, value, traceback) 66 session: Session to close after delay. 67 """ 68 _, value, _ = exc_info 69 self._errors[source] = exc_info 70 logging.info('Error recorded from %s: %s', source, value) 71 72 if session is not None and self._session_cancel_timer is None: 73 74 def _cancel_session(): 75 time.sleep(5) 76 try: 77 session.close() 78 except: # pylint: disable=bare-except 79 pass 80 81 self._session_cancel_timer = threading.Thread(target=_cancel_session,) 82 self._session_cancel_timer.daemon = True 83 self._session_cancel_timer.start() 84 85 def record_done(self, source): 86 """Mark execution source `source` as done. 87 88 If an error was originally reported from `source` it is left intact. 89 90 Args: 91 source: `str`, source being recorded 92 """ 93 logging.info('%s marked as finished', source) 94 if source not in self._errors: 95 self._errors[source] = None 96 97 @contextlib.contextmanager 98 def catch_errors(self, source, session=None): 99 """Context manager to report any errors within a block.""" 100 try: 101 yield 102 except Exception: # pylint: disable=broad-except 103 self.record_error(source, sys.exc_info(), session) 104 105 def raise_errors(self, timeout_sec=0): 106 """Wait for up to `timeout` seconds for all error sources to finish. 107 108 Preferentially raise "interesting" errors (errors not in the 109 _UNINTERESTING_ERRORS) set. 110 111 Args: 112 timeout_sec: Seconds to wait for other error sources. 113 """ 114 for _ in range(timeout_sec): 115 if len(self._errors) == self._num_sources: 116 break 117 time.sleep(1) 118 119 kept_errors = [(k, v) for (k, v) in self._errors.items() if v is not None] 120 121 # First check for any interesting errors, then fall back on the session 122 # cancelled errors etc. 123 for k, (typ, value, traceback) in kept_errors: 124 if isinstance(value, _UNINTERESTING_ERRORS): 125 continue 126 else: 127 logging.warn('Reraising captured error') 128 six.reraise(typ, value, traceback) 129 130 for k, (typ, value, traceback) in kept_errors: 131 logging.warn('Reraising captured error') 132 six.reraise(typ, value, traceback) 133