# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.numerics."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
from tensorflow.python.platform import test

# Pattern for the ValueError that `numerics.add_check_numerics_ops()` is
# expected to raise when the graph contains control flow. Every literal regex
# metacharacter is escaped (including the dot in `tf.while_loop`, which was
# previously left as a match-any-character wildcard).
_CONTROL_FLOW_ERROR_REGEX = (
    r"`tf\.add_check_numerics_ops\(\) is not compatible with "
    r"TensorFlow control flow operations such as `tf\.cond\(\)` "
    r"or `tf\.while_loop\(\)`\.")


class VerifyTensorAllFiniteTest(test.TestCase):
  """Tests for `numerics.verify_tensor_all_finite`."""

  def testVerifyTensorAllFiniteSucceeds(self):
    """A tensor of finite values evaluates to (approximately) its input."""
    x_shape = [5, 4]
    x = np.random.random_sample(x_shape).astype(np.float32)
    with test_util.use_gpu():
      t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32)
      t_verified = numerics.verify_tensor_all_finite(t,
                                                     "Input is not a number.")
      self.assertAllClose(x, self.evaluate(t_verified))

  def testVerifyTensorAllFiniteFails(self):
    """A tensor containing NaN or Inf raises an op error carrying the message."""
    x_shape = [5, 4]
    x = np.random.random_sample(x_shape).astype(np.float32)
    my_msg = "Input is not a number."

    # Overwrite the whole first row with each non-finite value in turn; each
    # iteration replaces the previous value, so the NaN and Inf cases are
    # checked independently.
    for bad_value in (np.nan, np.inf):
      x[0] = bad_value
      with test_util.use_gpu():
        with self.assertRaisesOpError(my_msg):
          t = constant_op.constant(x, shape=x_shape, dtype=dtypes.float32)
          t_verified = numerics.verify_tensor_all_finite(t, my_msg)
          self.evaluate(t_verified)


@test_util.run_v1_only("add_check_numerics_op() is meant to be a v1-only API")
class NumericsTest(test.TestCase):
  """Tests for `numerics.add_check_numerics_ops` and `check_numerics`."""

  def _assertDivisionFailsCheck(self, numerator, denominator, expected_msg):
    """Asserts that checking `numerator / denominator` raises `expected_msg`.

    Builds the division in a fresh graph, adds the ops returned by
    `numerics.add_check_numerics_ops()`, makes the division result depend on
    them, and evaluates it inside `assertRaisesOpError(expected_msg)`.

    Args:
      numerator: Python value passed to `constant_op.constant`.
      denominator: Python value passed to `constant_op.constant`.
      expected_msg: Pattern the raised op error must match.
    """
    with self.session(graph=ops.Graph()):
      t1 = constant_op.constant(numerator)
      t2 = constant_op.constant(denominator)
      a = math_ops.div(t1, t2)
      check = numerics.add_check_numerics_ops()
      # Add a control dependency so the checks run whenever `a` is evaluated.
      a = control_flow_ops.with_dependencies([check], a)
      with self.assertRaisesOpError(expected_msg):
        self.evaluate(a)

  def testInf(self):
    """1 / 0 yields Inf, which the added checks report."""
    self._assertDivisionFailsCheck(1.0, 0.0, "Inf")

  def testNaN(self):
    """0 / 0 yields NaN, which the added checks report."""
    self._assertDivisionFailsCheck(0.0, 0.0, "NaN")

  def testBoth(self):
    """A result containing both Inf and NaN reports both."""
    self._assertDivisionFailsCheck([1.0, 0.0], [0.0, 0.0], "Inf and NaN")

  def testPassThrough(self):
    """`check_numerics` on finite data returns the same values and shape."""
    with self.session(graph=ops.Graph()):
      t1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
      checked = array_ops.check_numerics(t1, message="pass through test")
      value = self.evaluate(checked)
      self.assertAllEqual(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), value)
      self.assertEqual([2, 3], checked.get_shape())

  def testControlFlowCond(self):
    """A graph containing `cond` makes `add_check_numerics_ops` raise."""
    predicate = array_ops.placeholder(dtypes.bool, shape=[])
    _ = control_flow_ops.cond(predicate,
                              lambda: constant_op.constant([37.]),
                              lambda: constant_op.constant([42.]))
    with self.assertRaisesRegex(ValueError, _CONTROL_FLOW_ERROR_REGEX):
      numerics.add_check_numerics_ops()

  def testControlFlowWhile(self):
    """A graph containing `while_loop` makes `add_check_numerics_ops` raise."""
    predicate = array_ops.placeholder(dtypes.bool, shape=[])
    _ = control_flow_ops.while_loop(lambda _: predicate,
                                    lambda _: constant_op.constant([37.]),
                                    [constant_op.constant([42.])])
    with self.assertRaisesRegex(ValueError, _CONTROL_FLOW_ERROR_REGEX):
      numerics.add_check_numerics_ops()


if __name__ == "__main__":
  test.main()