1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Execution Callbacks for Eager Mode.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import contextlib 22import functools 23import enum # pylint: disable=g-bad-import-order 24 25import numpy as np 26 27from tensorflow.python import pywrap_tensorflow 28from tensorflow.python.eager import context 29from tensorflow.python.eager import core 30from tensorflow.python.eager import execute 31from tensorflow.python.platform import tf_logging as logging 32 33 34class ExecutionCallback(enum.Enum): 35 """Valid callback actions. 36 37 These can be passed to `seterr` or `errstate` to create callbacks when 38 specific events occur (e.g. an operation produces `NaN`s). 39 40 IGNORE: take no action. 41 PRINT: print a warning to `stdout`. 42 RAISE: raise an error (e.g. `InfOrNanError`). 43 WARN: print a warning using `tf.logging.warn`. 44 """ 45 46 IGNORE = "ignore" 47 PRINT = "print" 48 RAISE = "raise" 49 WARN = "warn" 50 51_DEFAULT_CALLBACK_ACTION = ExecutionCallback.RAISE 52 53 54# TODO(cais): Consider moving this exception class to errors_impl.py. 
class InfOrNanError(Exception):
  """Exception for inf and/or nan being present in tensor.

  The exception message summarizes which output of which operation contained
  the offending values, how many `inf`s / `nan`s were found and the tensor
  value itself. The constructor arguments are exposed as read-only properties.
  """

  def __init__(self,
               op_type,
               op_name,
               output_index,
               num_outputs,
               value):
    """Constructor of InfOrNanError.

    Args:
      op_type: Type name of the op that generated the tensor with
        `inf`(s) or `nan`(s) (e.g., `Div`).
      op_name: Name of the op that generated the tensor with `inf`(s) or
        `nan`(s). This name is set by client and can be `None` if it is unset.
      output_index: The 0-based output index of the tensor that contains
        `inf`(s) or `nan`(s).
      num_outputs: Total number of outputs of the operation.
      value: The tensor value that contains `inf`(s) or `nan`(s).
    """
    self._op_type = op_type
    self._op_name = op_name
    self._output_index = output_index
    self._num_outputs = num_outputs
    self._value = value

    # Element statistics are computed once, up front, because the error
    # message passed to the base class below depends on them.
    self._total_count = np.size(value)
    self._inf_count = np.count_nonzero(np.isinf(value))
    self._nan_count = np.count_nonzero(np.isnan(value))

    super(InfOrNanError, self).__init__(self._get_error_message())

  def _get_error_message(self):
    """Get the error message describing this InfOrNanError object."""
    # Quote the op name only when one was actually provided; otherwise the
    # message shows the bare `None`.
    name_str = (("'%s'" % self._op_name) if self._op_name is not None
                else str(self._op_name))
    # The output index is reported 1-based in the human-readable message.
    msg = "Output %d of %d of TFE operation %s (name: %s) contains " % (
        self._output_index + 1, self._num_outputs, self._op_type, name_str)
    if self._inf_count and self._nan_count:
      msg += "%d inf(s) and %d nan(s) " % (self._inf_count, self._nan_count)
    elif self._inf_count:
      msg += "%d inf(s) " % self._inf_count
    else:
      msg += "%d nan(s) " % self._nan_count
    msg += "out of a total of %d element(s). Tensor value: %s" % (
        self._total_count, self._value)
    return msg

  @property
  def op_type(self):
    """Type name of the op that produced the offending tensor."""
    return self._op_type

  @property
  def op_name(self):
    """Client-assigned name of the op; may be `None`."""
    return self._op_name

  @property
  def output_index(self):
    """0-based index of the offending output."""
    return self._output_index

  @property
  def num_outputs(self):
    """Total number of outputs of the op."""
    return self._num_outputs

  @property
  def value(self):
    """The tensor value that contains `inf`(s) and/or `nan`(s)."""
    return self._value


def inf_nan_callback(op_type,
                     inputs,
                     attrs,
                     outputs,
                     op_name,
                     check_inf=True,
                     check_nan=True,
                     action=_DEFAULT_CALLBACK_ACTION):
  """An execution callback that checks for `inf`s and `nan`s in output tensors.

  This callback can be used with `tfe.add_execution_callback` to check for
  invalid numeric values. E.g.,
  ```python
  tfe.add_execution_callback(tfe.inf_nan_callback)
  ```

  Args:
    op_type: Name of the TFE operation type (e.g., `MatMul`).
    inputs: The `list` of input tensors to the operation, currently unused by
      this callback.
    attrs: Attributes of the TFE operation, as a tuple of alternating attribute
      names and attribute values. Currently unused by this callback.
    outputs: The `list` of output tensors from the operation, checked by this
      callback for `inf` and `nan` values.
    op_name: Name of the TFE operation. This name is set by client and can be
      `None` if it unset.
    check_inf: (`bool`) Whether this callback should check for `inf` values in
      the output tensor values.
    check_nan: (`bool`) Whether this callback should check for `nan` values in
      the output tensor values.
    action: (`ExecutionCallback`) Action to be taken by the callback when
      `inf` or `nan` values are detected.

  Raises:
    InfOrNanError: iff `inf` or `nan` values are seen in any of `outputs` and
      `action` is `"raise"`.
    ValueError: iff the value of `action` is invalid. Note that `IGNORE` is
      rejected here as well once an `inf`/`nan` has been detected: only
      `PRINT`, `WARN` and `RAISE` are handled by this callback.
  """
  del attrs, inputs  # Not used.

  # Accept both enum members and their string values; raises ValueError for
  # anything that is not a valid `ExecutionCallback`.
  action = ExecutionCallback(action)
  ctx = context.context()

  for index, output in enumerate(outputs):
    if not output.dtype.is_numpy_compatible:
      continue

    numpy_dtype = output.dtype.as_numpy_dtype
    # `np.complexfloating` is the abstract parent of all complex dtypes; the
    # `np.complex` alias is deprecated/removed in recent NumPy releases and
    # does not reliably match `complex64`.
    if not (np.issubdtype(numpy_dtype, np.floating) or
            np.issubdtype(numpy_dtype, np.complexfloating) or
            np.issubdtype(numpy_dtype, np.integer)):
      continue

    # The "T" attr must describe the tensor actually being checked, which for
    # multi-output ops is not necessarily the dtype of `outputs[0]`.
    check_numerics_op_attrs = (
        "message", "Eager-mode inf/nan check",
        "T", output.dtype.as_datatype_enum)
    try:
      # TODO(cais): Consider moving this into execute.py.
      # pylint: disable=protected-access
      pywrap_tensorflow.TFE_Py_Execute(
          ctx._handle, output.device, "CheckNumerics", [output],
          check_numerics_op_attrs, 1)
      # pylint: enable=protected-access
    except core._NotOkStatusException:  # pylint: disable=protected-access
      # The op failed: inspect the value on the Python side to find out which
      # kind of bad value (if any of the requested kinds) is present.
      value = output.numpy()
      inf_detected = np.any(np.isinf(value)) and check_inf
      nan_detected = np.any(np.isnan(value)) and check_nan
      if not inf_detected and not nan_detected:
        continue

      error = InfOrNanError(op_type, op_name, index, len(outputs), value)
      if action == ExecutionCallback.PRINT:
        print("Warning: %s" % str(error))
      elif action == ExecutionCallback.WARN:
        logging.warn(str(error))
      elif action == ExecutionCallback.RAISE:
        raise error
      else:
        raise ValueError(
            "Invalid action for inf_nan_callback: %s. Valid actions are: "
            "{PRINT | WARN | RAISE}" % action)


def inf_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """A specialization of `inf_nan_callback` that checks for `inf`s only."""
  inf_nan_callback(
      op_type,
      inputs,
      attrs,
      outputs,
      op_name,
      check_inf=True,
      check_nan=False,
      action=action)


def nan_callback(op_type,
                 inputs,
                 attrs,
                 outputs,
                 op_name,
                 action=_DEFAULT_CALLBACK_ACTION):
  """A specialization of `inf_nan_callback` that checks for `nan`s only."""
  inf_nan_callback(
      op_type,
      inputs,
      attrs,
      outputs,
      op_name,
      check_inf=False,
      check_nan=True,
      action=action)


def add_execution_callback(callback):
  """Add an execution callback to the default eager context.

  An execution callback is invoked immediately after an eager operation or
  function has finished execution, providing access to the op's type, name,
  input and output tensors. Multiple execution callbacks can be added, in
  which case the callbacks will be invoked in the order in which they are
  added. To clear all execution callbacks that have been added, use
  `clear_execution_callbacks()`.

  Example:
  ```python
  def print_even_callback(op_type, inputs, attrs, outputs, op_name):
    # A callback that prints only the even output values.
    if outputs[0].numpy() % 2 == 0:
      print("Even output from %s: %s" % (op_name or op_type, outputs))
  tfe.add_execution_callback(print_even_callback)

  x = tf.pow(2.0, 3.0) - 3.0
  y = tf.multiply(x, tf.add(1.0, 5.0))
  # When the line above is run, you will see all intermediate outputs that are
  # even numbers printed to the console.

  tfe.clear_execution_callbacks()
  ```

  Args:
    callback: a callable of the signature
      `f(op_type, inputs, attrs, outputs, op_name)` (the same positional
      order used by `inf_nan_callback` in this module).
      `op_type` is the type of the operation that was just executed (e.g.,
        `MatMul`).
      `inputs` is the `list` of input `Tensor`(s) to the op.
      `attrs` contains the attributes of the operation as a `tuple` of
        alternating attribute name and attribute value.
      `outputs` is the `list` of output `Tensor`(s) from the op.
      `op_name` is the name of the operation that was just executed. This
        name is set by the client who created the operation and can be `None`
        if it is unset.
      Return value(s) from the callback are ignored.
  """
  # Replace the module-level `execute.execute` with the callback-aware variant
  # before registering the callback on the default context.
  execute.execute = execute.execute_with_callbacks
  context.context().add_post_execution_callback(callback)


def clear_execution_callbacks():
  """Clear all execution callbacks from the default eager context."""
  context.context().clear_post_execution_callbacks()


def seterr(inf_or_nan=None):
  """Set how abnormal conditions are handled by the default eager context.

  Example:
  ```python
  tfe.seterr(inf_or_nan=ExecutionCallback.RAISE)
  a = tf.constant(10.0)
  b = tf.constant(0.0)
  try:
    c = a / b  # <-- Raises InfOrNanError.
  except Exception as e:
    print("Caught Exception: %s" % e)

  tfe.seterr(inf_or_nan=ExecutionCallback.IGNORE)
  c = a / b  # <-- Does NOT raise exception anymore.
  ```

  Args:
    inf_or_nan: An `ExecutionCallback` determining the action for infinity
      (`inf`) and NaN (`nan`) values. A value of `None` leads to no change in
      the action of the condition.

  Returns:
    A dictionary of old actions, keyed by keyword-argument name (currently
    only `"inf_or_nan"`), suitable for passing back as `seterr(**old)`.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  inf_or_nan = ExecutionCallback(inf_or_nan) if inf_or_nan is not None else None
  # With no inf/nan callback installed, the effective old action is IGNORE.
  old_settings = {"inf_or_nan": ExecutionCallback.IGNORE}
  default_context = context.context()

  carryover_callbacks = []
  for callback in default_context.post_execution_callbacks:
    # Check whether the callback is inf_nan_callback or a partial object of
    # inf_nan_callback.
    if callback == inf_nan_callback:
      old_settings["inf_or_nan"] = _DEFAULT_CALLBACK_ACTION
    elif (isinstance(callback, functools.partial) and
          callback.func == inf_nan_callback):
      # Guard against a partial that carries no keyword arguments.
      keywords = callback.keywords or {}
      old_settings["inf_or_nan"] = keywords.get(
          "action", _DEFAULT_CALLBACK_ACTION)
    elif inf_or_nan is not None:
      # Unrelated callbacks are preserved across the reset below.
      carryover_callbacks.append(callback)

  if inf_or_nan is not None:
    # Rebuild the callback list: keep every unrelated callback, drop any
    # existing inf/nan callback, and install the new one unless ignoring.
    default_context.clear_post_execution_callbacks()
    for callback in carryover_callbacks:
      default_context.add_post_execution_callback(callback)
    if inf_or_nan != ExecutionCallback.IGNORE:
      default_context.add_post_execution_callback(
          functools.partial(inf_nan_callback, action=inf_or_nan))

  return old_settings


@contextlib.contextmanager
def errstate(inf_or_nan=None):
  """Context manager setting error state.

  The previous settings are restored when the `with` block exits, including
  when it exits because of an exception (e.g. the `InfOrNanError` raised under
  `ExecutionCallback.RAISE`).

  Example:
  ```
  c = tf.log(0.)  # -inf

  with errstate(inf_or_nan=ExecutionCallback.RAISE):
    tf.log(0.)  # <-- Raises InfOrNanError.
  ```

  Args:
    inf_or_nan: An `ExecutionCallback` determining the action for infinity
      (`inf`) and NaN (`nan`) values. A value of `None` leads to no change in
      the action of the condition.

  Yields:
    None.

  Raises:
    ValueError: If the value of any keyword arguments is invalid.
  """
  if not context.executing_eagerly():
    # Nothing to configure outside of eager execution.
    yield
  else:
    old_settings = seterr(inf_or_nan=inf_or_nan)
    try:
      yield
    finally:
      # Always restore, otherwise an exception raised inside the `with` body
      # would leave the temporary callback installed.
      seterr(**old_settings)