# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for eager execution_callbacks."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import execution_callbacks
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test

# Shorthands for the two inf/nan handling policies exercised by the tests.
RAISE = execution_callbacks.ExecutionCallback.RAISE
IGNORE = execution_callbacks.ExecutionCallback.IGNORE


def log_zero():
  """Computes `log(0.0)`.

  The tests below expect this to yield negative infinity, which makes it a
  convenient trigger for the inf/nan execution callbacks.

  Returns:
    The result of `math_ops.log` applied to a scalar `0.` constant.
  """
  zero = constant_op.constant(0.)
  return math_ops.log(zero)


class ExecutionCallbacksTest(test.TestCase):
  """Checks how `execution_callbacks.errstate` reacts to inf/nan results."""

  def _expect_inf_or_nan_error(self):
    """Asserts that evaluating `log_zero()` raises `InfOrNanError`."""
    with self.assertRaises(execution_callbacks.InfOrNanError):
      log_zero()

  def _expect_negative_infinity(self):
    """Asserts that `log_zero()` quietly evaluates to negative infinity."""
    result = log_zero().numpy()
    self.assertEqual(-float("inf"), result)

  def test_errstate_inf_raise(self):
    with execution_callbacks.errstate(inf_or_nan=RAISE):
      self._expect_inf_or_nan_error()

  def test_errstate_inf_ignore(self):
    with execution_callbacks.errstate(inf_or_nan=IGNORE):
      self._expect_negative_infinity()

  def test_errstate_nesting(self):
    with execution_callbacks.errstate(inf_or_nan=RAISE):
      # While the inner errstate is active, its IGNORE policy wins...
      with execution_callbacks.errstate(inf_or_nan=IGNORE):
        self._expect_negative_infinity()

      # ...and once it exits, the enclosing RAISE policy applies again.
      self._expect_inf_or_nan_error()


if __name__ == "__main__":
  # The tests read results through `.numpy()`, which presumably needs eager
  # tensors, so eager mode is switched on before the runner starts.
  ops.enable_eager_execution()
  test.main()