• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.numerics."""
16
17import numpy as np
18
19from tensorflow.python.framework import constant_op
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import ops
22from tensorflow.python.framework import test_util
23from tensorflow.python.ops import array_ops
24from tensorflow.python.ops import control_flow_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops import numerics
27from tensorflow.python.platform import test
28
29
class VerifyTensorAllFiniteTest(test.TestCase):
  """Tests for `numerics.verify_tensor_all_finite`."""

  def testVerifyTensorAllFiniteSucceeds(self):
    """A tensor holding only finite values is returned with its values intact."""
    shape = [5, 4]
    values = np.random.random_sample(shape).astype(np.float32)
    with test_util.use_gpu():
      verified = numerics.verify_tensor_all_finite(
          constant_op.constant(values, shape=shape, dtype=dtypes.float32),
          "Input is not a number.")
      self.assertAllClose(values, self.evaluate(verified))

  def testVerifyTensorAllFiniteFails(self):
    """Evaluating the check raises an op error carrying `msg` for NaN and Inf."""
    shape = [5, 4]
    values = np.random.random_sample(shape).astype(np.float32)
    msg = "Input is not a number."

    # `values[0]` is the whole first row of the (5, 4) array. It is overwritten
    # in place, so the Inf pass no longer contains the NaN from the NaN pass.
    for non_finite in (np.nan, np.inf):
      values[0] = non_finite
      with test_util.use_gpu(), self.assertRaisesOpError(msg):
        tensor = constant_op.constant(values, shape=shape, dtype=dtypes.float32)
        self.evaluate(numerics.verify_tensor_all_finite(tensor, msg))
61
62
@test_util.run_v1_only("add_check_numerics_op() is meant to be a v1-only API")
class NumericsTest(test.TestCase):
  """Tests for `numerics.add_check_numerics_ops` and `array_ops.check_numerics`."""

  # Expected ValueError message when add_check_numerics_ops() is called on a
  # graph that contains control flow. Every literal `.` is escaped so the
  # pattern matches the message text exactly rather than "any character".
  _CONTROL_FLOW_ERROR_REGEX = (
      r"`tf\.add_check_numerics_ops\(\) is not compatible with "
      r"TensorFlow control flow operations such as `tf\.cond\(\)` "
      r"or `tf\.while_loop\(\)`\.")

  def _assertDivisionCheckFails(self, numerator, denominator, expected_msg):
    """Asserts that evaluating a checked `numerator / denominator` raises.

    Builds the division in a fresh graph, adds the ops returned by
    `numerics.add_check_numerics_ops()` as a control dependency of the
    quotient, and expects evaluation to raise an op error.

    Args:
      numerator: Value passed to `constant_op.constant` for the dividend.
      denominator: Value passed to `constant_op.constant` for the divisor.
      expected_msg: Regex that the raised op error message must match.
    """
    with self.session(graph=ops.Graph()):
      # The dividend constant is created before the divisor, matching the
      # op-creation order of the original per-test code.
      quotient = math_ops.div(constant_op.constant(numerator),
                              constant_op.constant(denominator))
      check = numerics.add_check_numerics_ops()
      # Evaluating `quotient` must now also run the check op.
      quotient = control_flow_ops.with_dependencies([check], quotient)
      with self.assertRaisesOpError(expected_msg):
        self.evaluate(quotient)

  def testInf(self):
    """1 / 0 is reported as Inf."""
    self._assertDivisionCheckFails(1.0, 0.0, "Inf")

  def testNaN(self):
    """0 / 0 is reported as NaN."""
    self._assertDivisionCheckFails(0.0, 0.0, "NaN")

  def testBoth(self):
    """A tensor containing both an Inf and a NaN reports both."""
    self._assertDivisionCheckFails([1.0, 0.0], [0.0, 0.0], "Inf and NaN")

  def testPassThrough(self):
    """check_numerics returns finite input unchanged and keeps its shape."""
    with self.session(graph=ops.Graph()):
      t1 = constant_op.constant([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], shape=[2, 3])
      checked = array_ops.check_numerics(t1, message="pass through test")
      value = self.evaluate(checked)
      self.assertAllEqual(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]), value)
      self.assertEqual([2, 3], checked.get_shape())

  def testControlFlowCond(self):
    """add_check_numerics_ops() raises ValueError on a graph with a cond."""
    predicate = array_ops.placeholder(dtypes.bool, shape=[])
    _ = control_flow_ops.cond(predicate,
                              lambda: constant_op.constant([37.]),
                              lambda: constant_op.constant([42.]))
    with self.assertRaisesRegex(ValueError, self._CONTROL_FLOW_ERROR_REGEX):
      numerics.add_check_numerics_ops()

  def testControlFlowWhile(self):
    """add_check_numerics_ops() raises ValueError on a graph with a while loop."""
    predicate = array_ops.placeholder(dtypes.bool, shape=[])
    _ = control_flow_ops.while_loop(lambda _: predicate,
                                    lambda _: constant_op.constant([37.]),
                                    [constant_op.constant([42.])])
    with self.assertRaisesRegex(ValueError, self._CONTROL_FLOW_ERROR_REGEX):
      numerics.add_check_numerics_ops()
125
126
if __name__ == "__main__":
  # Only when this file is executed directly: hand off to the test entry
  # point of the `test` module imported at the top of the file.
  test.main()
129