• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===================================================================
"""ErrorRendezvous handler for collecting errors from multiple threads."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import sys
23import threading
24import time
25
26import six
27
28from tensorflow.python.framework import errors
29from tensorflow.python.platform import tf_logging as logging
30
# Exception types that `ErrorRendezvous.raise_errors` reports only as a last
# resort, i.e. when no other kind of error was captured.  Presumably these are
# knock-on failures (such as a cancelled session) rather than root causes.
_UNINTERESTING_ERRORS = (errors.CancelledError,)
32
33
class ErrorRendezvous(object):
  """Resolve errors from multiple threads during TPU execution.

  TPU errors can occur on the infeed or outfeed threads as well as the main
  training thread.

  Depending on which thread "wins" and receives the session error first, we may
  end up showing users a confusing and non-actionable error message (session
  cancelled) instead of a root cause (e.g. a bad filename).

  The rendezvous object provides a location to capture these errors until all
  threads terminate.  At that point we can choose the most informative error
  to report.
  """

  # Seconds to wait before closing the session after the first error, giving
  # other threads a chance to report a more informative error first.
  _SESSION_CLOSE_DELAY_SECS = 5

  def __init__(self, num_sources):
    """Creates an empty rendezvous.

    Args:
      num_sources: Number of distinct sources expected to report (through
        `record_error` or `record_done`) before `raise_errors` stops waiting.
    """
    # Guards `_errors` and `_session_cancel_timer`, both of which are written
    # from several threads.
    self._lock = threading.Lock()
    # source name -> `sys.exc_info()` triple, or None when the source finished
    # without reporting an error.
    self._errors = {}
    self._num_sources = num_sources
    self._session_cancel_timer = None

  def _close_session_after_delay(self, session):
    """Sleeps for the grace period, then closes `session` (best effort)."""
    time.sleep(self._SESSION_CLOSE_DELAY_SECS)
    try:
      session.close()
    except Exception as e:  # pylint: disable=broad-except
      # Closing is best effort, but a failure here is worth surfacing instead
      # of silently discarding it.
      logging.warn('Failed to close session after error: %s', e)

  def record_error(self, source, exc_info, session=None):
    """Report an exception from the given source.

    If a session is passed, a timer will be registered to close it after a few
    seconds.  This is necessary to ensure the main training loop does not hang
    if an infeed/outfeed error occurs.  We sleep a few seconds to allow a more
    interesting error from another thread to propagate.

    Only one such timer is ever started per rendezvous, no matter how many
    errors are recorded.

    Args:
      source: string, source of the error
      exc_info: Output from `sys.exc_info` (type, value, traceback)
      session: Session to close after delay.
    """
    _, value, _ = exc_info
    timer = None
    with self._lock:
      self._errors[source] = exc_info
      # Check-and-set under the lock so that two threads failing at the same
      # time cannot both start a closer thread.
      if session is not None and self._session_cancel_timer is None:
        timer = threading.Thread(
            target=self._close_session_after_delay, args=(session,))
        timer.daemon = True
        self._session_cancel_timer = timer

    logging.info('Error recorded from %s: %s', source, value)
    if timer is not None:
      timer.start()

  def record_done(self, source):
    """Mark execution source `source` as done.

    If an error was originally reported from `source` it is left intact.

    Args:
      source: `str`, source being recorded
    """
    logging.info('%s marked as finished', source)
    with self._lock:
      # Atomic "insert if absent": never overwrite an already recorded error.
      self._errors.setdefault(source, None)

  @contextlib.contextmanager
  def catch_errors(self, source, session=None):
    """Context manager to report any errors within a block.

    An `Exception` raised inside the block is recorded through `record_error`
    and is not propagated to the caller.

    Args:
      source: string, name under which an error is recorded.
      session: Optional session, forwarded to `record_error`.

    Yields:
      None.
    """
    try:
      yield
    except Exception:  # pylint: disable=broad-except
      self.record_error(source, sys.exc_info(), session)

  def raise_errors(self, timeout_sec=0):
    """Wait for up to `timeout_sec` seconds for all error sources to finish.

    Preferentially raise "interesting" errors (errors not in the
    _UNINTERESTING_ERRORS) set.  Returns normally if no source recorded an
    error.

    Args:
      timeout_sec: Seconds to wait for other error sources.  May be fractional;
        zero or a negative value disables waiting.

    Raises:
      The captured exception chosen as described above, with its original
      traceback.
    """
    # `time.monotonic` is unavailable on Python 2; fall back to wall time.
    clock = getattr(time, 'monotonic', time.time)
    deadline = clock() + timeout_sec
    while len(self._errors) != self._num_sources:
      remaining = deadline - clock()
      if remaining <= 0:
        break
      # Poll roughly once per second without overshooting the deadline.
      time.sleep(min(1, remaining))

    # Snapshot under the lock: sources may still be reporting if we timed out,
    # and iterating a dict while it grows raises RuntimeError.
    with self._lock:
      kept_errors = [v for v in self._errors.values() if v is not None]

    if not kept_errors:
      return

    # First check for any interesting errors, then fall back on the session
    # cancelled errors etc.
    chosen = kept_errors[0]
    for candidate in kept_errors:
      if not isinstance(candidate[1], _UNINTERESTING_ERRORS):
        chosen = candidate
        break

    typ, value, traceback = chosen
    logging.warn('Reraising captured error')
    six.reraise(typ, value, traceback)
133