• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Execution Callbacks for Eager Mode."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import contextlib
22import functools
23import enum  # pylint: disable=g-bad-import-order
24
25import numpy as np
26
27from tensorflow.python import pywrap_tensorflow
28from tensorflow.python.eager import context
29from tensorflow.python.eager import core
30from tensorflow.python.eager import execute
31from tensorflow.python.platform import tf_logging as logging
32
33
class ExecutionCallback(enum.Enum):
  """Actions a callback may take when a monitored condition is detected.

  Members of this enum are accepted by `seterr` and `errstate` to select what
  happens when a specific event occurs (e.g. an operation produces `NaN`s).
  Each member's value is its lower-case name, so the plain strings `"ignore"`,
  `"print"`, `"raise"` and `"warn"` can be converted with
  `ExecutionCallback(value)`.
  """

  # Take no action.
  IGNORE = "ignore"
  # Print a warning to `stdout`.
  PRINT = "print"
  # Raise an error (e.g. `InfOrNanError`).
  RAISE = "raise"
  # Emit a warning through `tf.logging.warn`.
  WARN = "warn"


# Action used by the callbacks in this module when the caller does not
# specify one.
_DEFAULT_CALLBACK_ACTION = ExecutionCallback.RAISE
52
53
# TODO(cais): Consider moving this exception class to errors_impl.py.
class InfOrNanError(Exception):
  """Raised when a tensor value contains `inf` and/or `nan` elements.

  The exception message is built once, at construction time, from the op
  description and the number of `inf`/`nan` elements counted in `value`.
  """

  def __init__(self,
               op_type,
               op_name,
               output_index,
               num_outputs,
               value):
    """Creates an InfOrNanError.

    Args:
      op_type: Type name of the op whose output holds the `inf`(s) or
        `nan`(s) (e.g., `Div`).
      op_name: Client-assigned name of that op; `None` when no name was set.
      output_index: 0-based index of the offending output tensor.
      num_outputs: Total number of outputs of the operation.
      value: The tensor value that contains `inf`(s) or `nan`(s). It is
        inspected with `np.size`, `np.isinf` and `np.isnan`.
    """
    self._op_type = op_type
    self._op_name = op_name
    self._output_index = output_index
    self._num_outputs = num_outputs
    self._value = value

    # Element statistics are computed eagerly so the message can report them.
    self._total_count = np.size(value)
    self._inf_count = np.count_nonzero(np.isinf(value))
    self._nan_count = np.count_nonzero(np.isnan(value))

    super(InfOrNanError, self).__init__(self._get_error_message())

  def _get_error_message(self):
    """Returns the human-readable description of this error."""
    # A real name is shown quoted; a missing name is shown as a bare `None`.
    if self._op_name is None:
      name_str = str(self._op_name)
    else:
      name_str = "'%s'" % self._op_name

    # The output position is reported 1-based.
    header = "Output %d of %d of TFE operation %s (name: %s) contains " % (
        self._output_index + 1, self._num_outputs, self._op_type, name_str)

    if self._inf_count and self._nan_count:
      counts = "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count)
    elif self._inf_count:
      counts = "%d inf(s) " % self._inf_count
    else:
      counts = "%d nan(s) " % self._nan_count

    footer = "out of a total of %d element(s). Tensor value: %s" % (
        self._total_count, self._value)
    return header + counts + footer

  @property
  def op_type(self):
    """Type name of the op that produced the value."""
    return self._op_type

  @property
  def op_name(self):
    """Client-assigned op name, or `None` if unset."""
    return self._op_name

  @property
  def output_index(self):
    """0-based index of the offending output."""
    return self._output_index

  @property
  def num_outputs(self):
    """Total number of outputs of the op."""
    return self._num_outputs

  @property
  def value(self):
    """The value that was found to contain `inf`(s) or `nan`(s)."""
    return self._value
123
124
def inf_nan_callback(op_type,
                     inputs,
                     attrs,
                     outputs,
                     op_name,
                     check_inf=True,
                     check_nan=True,
                     action=_DEFAULT_CALLBACK_ACTION):
  """An execution callback that checks for `inf`s and `nan`s in output tensors.

  This callback can be used with `tfe.add_execution_callback` to check for
  invalid numeric values. E.g.,
  ```python
  tfe.add_execution_callback(tfe.inf_nan_callback)
  ```

  Each numeric output is first passed through the `CheckNumerics` op; only when
  that op reports a failure is the value pulled into NumPy and inspected for
  `inf`/`nan` elements.

  Args:
    op_type: Name of the TFE operation type (e.g., `MatMul`).
    inputs: The `list` of input tensors to the operation, currently unused by
      this callback.
    attrs: Attributes of the TFE operation, as a tuple of alternating attribute
      names and attribute values. Currently unused by this callback.
    outputs: The `list` of output tensors from the operation, checked by this
      callback for `inf` and `nan` values.
    op_name: Name of the TFE operation. This name is set by client and can be
      `None` if it is unset.
    check_inf: (`bool`) Whether this callback should check for `inf` values in
      the output tensor values.
    check_nan: (`bool`) Whether this callback should check for `nan` values in
      the output tensor values.
    action: (`ExecutionCallback`) Action to be taken by the callback when
      `inf` or `nan` values are detected.

  Raises:
    InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
      `action` is `ExecutionCallback.RAISE`.
    ValueError: iff the value of `action` is not a valid `ExecutionCallback`,
      or iff `inf`/`nan` values are detected while `action` is
      `ExecutionCallback.IGNORE` (this callback only supports PRINT, WARN and
      RAISE).
  """
  del attrs, inputs  # Not used.

  # Normalizes strings such as "raise"; raises ValueError for unknown values.
  action = ExecutionCallback(action)
  ctx = context.context()

  for index, output in enumerate(outputs):
    if not output.dtype.is_numpy_compatible:
      continue

    numpy_dtype = output.dtype.as_numpy_dtype
    # `np.complexfloating` is the abstract base of all complex dtypes; the
    # `np.complex` alias used previously has been removed from recent NumPy.
    if not (np.issubdtype(numpy_dtype, np.floating) or
            np.issubdtype(numpy_dtype, np.complexfloating) or
            np.issubdtype(numpy_dtype, np.integer)):
      continue

    # The "T" attribute must describe the tensor actually being checked, not
    # the first output, since an op's outputs may have different dtypes.
    check_numerics_op_attrs = (
        "message", "Eager-mode inf/nan check",
        "T", output.dtype.as_datatype_enum)
    try:
      # TODO(cais): Consider moving this into execute.py.
      # pylint: disable=protected-access
      pywrap_tensorflow.TFE_Py_Execute(
          ctx._handle, output.device, "CheckNumerics", [output],
          check_numerics_op_attrs, 1)
      # pylint: enable=protected-access
    except core._NotOkStatusException:  # pylint: disable=protected-access
      # The op failed; inspect the concrete value to find out whether the
      # failure is one of the conditions the caller asked to check for.
      value = output.numpy()
      inf_detected = np.any(np.isinf(value)) and check_inf
      nan_detected = np.any(np.isnan(value)) and check_nan
      if not inf_detected and not nan_detected:
        continue

      error = InfOrNanError(op_type, op_name, index, len(outputs), value)
      if action == ExecutionCallback.PRINT:
        print("Warning: %s" % str(error))
      elif action == ExecutionCallback.WARN:
        logging.warn(str(error))
      elif action == ExecutionCallback.RAISE:
        raise error
      else:
        # Only reachable for ExecutionCallback.IGNORE, which `seterr` handles
        # by not installing this callback at all.
        raise ValueError(
            "Invalid action for inf_nan_callback: %s. Valid actions are: "
            "{PRINT | WARN | RAISE}" % action)
204
205
def inf_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """Checks the op's outputs for `inf` values only.

  Delegates to `inf_nan_callback` with the `nan` check switched off; all other
  arguments are forwarded unchanged.
  """
  inf_nan_callback(op_type, inputs, attrs, outputs, op_name,
                   check_inf=True, check_nan=False, action=action)
222
223
def nan_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """Checks the op's outputs for `nan` values only.

  Delegates to `inf_nan_callback` with the `inf` check switched off; all other
  arguments are forwarded unchanged.
  """
  inf_nan_callback(op_type, inputs, attrs, outputs, op_name,
                   check_inf=False, check_nan=True, action=action)
240
241
def add_execution_callback(callback):
  """Add an execution callback to the default eager context.

  An execution callback is invoked immediately after an eager operation or
  function has finished execution, providing access to the op's type, name,
  input and output tensors. Multiple execution callbacks can be added, in
  which case the callbacks will be invoked in the order in which they are
  added. To clear all execution callbacks that have been added, use
  `clear_execution_callbacks()`.

  Example:
  ```python
  def print_even_callback(op_type, inputs, attrs, outputs, op_name):
    # A callback that prints only the even output values.
    if outputs[0].numpy() % 2 == 0:
      print("Even output from %s: %s" % (op_name or op_type,  outputs))
  tfe.add_execution_callback(print_even_callback)

  x = tf.pow(2.0, 3.0) - 3.0
  y = tf.multiply(x, tf.add(1.0, 5.0))
  # When the line above is run, you will see all intermediate outputs that are
  # even numbers printed to the console.

  tfe.clear_execution_callbacks()
  ```

  Args:
    callback: a callable of the signature
      `f(op_type, inputs, attrs, outputs, op_name)`, i.e. the same positional
      order as `inf_nan_callback` in this module.
      `op_type` is the type of the operation that was just executed (e.g.,
        `MatMul`).
      `inputs` is the `list` of input `Tensor`(s) to the op.
      `attrs` contains the attributes of the operation as a `tuple` of
        alternating attribute name and attribute value.
      `outputs` is the `list` of output `Tensor`(s) from the op.
      `op_name` is the name of the operation that was just executed. This
        name is set by the client who created the operation and can be `None` if
        it is unset.
       Return value(s) from the callback are ignored.
  """
  # Swap in the callback-aware executor module-wide, then register the
  # callback on the default context.
  execute.execute = execute.execute_with_callbacks
  context.context().add_post_execution_callback(callback)
284
285
def clear_execution_callbacks():
  """Remove every post-execution callback from the default eager context."""
  default_context = context.context()
  default_context.clear_post_execution_callbacks()
289
290
def seterr(inf_or_nan=None):
  """Set how abnormal conditions are handled by the default eager context.

  Any `inf_nan_callback` (or `functools.partial` wrapping it) currently
  registered on the default context is replaced by a new one configured with
  the requested action; all other callbacks are kept in their original order
  and the new `inf_nan_callback` is appended after them.

  Example:
  ```python
  tfe.seterr(inf_or_nan=ExecutionCallback.RAISE)
  a = tf.constant(10.0)
  b = tf.constant(0.0)
  try:
    c = a / b  # <-- Raises InfOrNanError.
  except Exception as e:
    print("Caught Exception: %s" % e)

  tfe.seterr(inf_or_nan=ExecutionCallback.IGNORE)
  c = a / b  # <-- Does NOT raise exception anymore.
  ```

  Args:
    inf_or_nan: An `ExecutionCallback` determining the action for infinity
      (`inf`) and NaN (`nan`) values. A value of `None` leads to no change in
      the action of the condition.

  Returns:
    A dictionary of old actions, keyed by `"inf_or_nan"`. The value is
    `ExecutionCallback.IGNORE` when no `inf_nan_callback` was registered.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  inf_or_nan = ExecutionCallback(inf_or_nan) if inf_or_nan is not None else None
  old_settings = {"inf_or_nan": ExecutionCallback.IGNORE}
  default_context = context.context()

  carryover_callbacks = []
  for callback in default_context.post_execution_callbacks:
    wraps_inf_nan_callback = (isinstance(callback, functools.partial) and
                              callback.func == inf_nan_callback)
    if callback == inf_nan_callback:
      # The bare function runs with its default action.
      old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    elif wraps_inf_nan_callback:
      # On Python 2, `partial.keywords` is None (not an empty dict) when the
      # partial was created without keyword arguments.
      partial_keywords = callback.keywords or {}
      old_settings["inf_or_nan"] = partial_keywords.get(
          "action", _DEFAULT_CALLBACK_ACTION)
    elif inf_or_nan is not None:
      # Unrelated callbacks survive the re-registration below.
      carryover_callbacks.append(callback)

  if inf_or_nan is not None:
    default_context.clear_post_execution_callbacks()
    for callback in carryover_callbacks:
      default_context.add_post_execution_callback(callback)
    # IGNORE is implemented by simply not installing the checking callback.
    if inf_or_nan != ExecutionCallback.IGNORE:
      default_context.add_post_execution_callback(
          functools.partial(inf_nan_callback, action=inf_or_nan))

  return old_settings
347
348
@contextlib.contextmanager
def errstate(inf_or_nan=None):
  """Context manager setting error state.

  On entry the requested action is installed via `seterr`; on exit the
  previous settings are restored, including when the body of the `with`
  statement raises (e.g. an `InfOrNanError` triggered by
  `ExecutionCallback.RAISE`). When not executing eagerly, the context manager
  does nothing.

  Example:
  ```
  c = tf.log(0.)  # -inf

  with errstate(inf_or_nan=ExecutionCallback.RAISE):
    tf.log(0.)  # <-- Raises InfOrNanError.
  ```

  Args:
    inf_or_nan: An `ExecutionCallback` determining the action for infinity
      (`inf`) and NaN (`nan`) values. A value of `None` leads to no change in
      the action of the condition.

  Yields:
    None.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  if not context.executing_eagerly():
    yield
  else:
    old_settings = seterr(inf_or_nan=inf_or_nan)
    try:
      yield
    finally:
      # Without the `finally`, an exception propagating out of the `with`
      # body would leave the temporary setting installed permanently.
      seterr(**old_settings)
378