# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for the `betainc` (incomplete beta function) op."""

import itertools

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


class BetaincTest(test.TestCase):
  """Checks `math_ops.betainc` values against scipy and its gradient w.r.t. x."""

  def _testBetaInc(self, a_s, b_s, x_s, dtype):
    """Compares `math_ops.betainc` with `scipy.special.betainc`.

    Covers four cases: the supplied arrays, a grid mixing in-range and
    out-of-range arguments, scalar/array broadcasting, and shape-mismatch
    errors. If scipy cannot be imported, a warning is logged and nothing is
    checked.

    Args:
      a_s: numpy array of `a` parameters (callers pass positive values).
      b_s: numpy array of `b` parameters (callers pass positive values).
      x_s: numpy array of `x` arguments (callers pass values in [0, 1)).
      dtype: TensorFlow dtype to run the op with; inputs are cast to the
        matching numpy dtype.
    """
    # Only the optional scipy import is guarded, so that an ImportError raised
    # anywhere else is not mistaken for "scipy is not installed".
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn("Cannot test special functions: %s", e)
      return

    np_dt = dtype.as_numpy_dtype

    # Test random values
    a_s = a_s.astype(np_dt)  # in (0, infty)
    b_s = b_s.astype(np_dt)  # in (0, infty)
    x_s = x_s.astype(np_dt)  # in (0, 1)
    tf_a_s = constant_op.constant(a_s, dtype=dtype)
    tf_b_s = constant_op.constant(b_s, dtype=dtype)
    tf_x_s = constant_op.constant(x_s, dtype=dtype)
    tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
    with self.cached_session():
      tf_out = self.evaluate(tf_out_t)
    scipy_out = special.betainc(a_s, b_s, x_s, dtype=np_dt)

    # the scipy version of betainc uses a double-only implementation.
    # TODO(ebrevdo): identify reasons for (sometime) precision loss
    # with doubles
    rtol = 1e-4
    atol = 1e-5
    self.assertAllCloseAccordingToType(
        scipy_out, tf_out, rtol=rtol, atol=atol)

    # Test out-of-range values (most should return nan output)
    combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
    # Transpose the list of (a, b, x) triples into three parallel arrays.
    a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
    with self.cached_session():
      tf_comb = math_ops.betainc(a_comb, b_comb, x_comb).eval()
    scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
    self.assertAllCloseAccordingToType(
        scipy_comb, tf_comb, rtol=rtol, atol=atol)

    # Test broadcasting between scalars and other shapes
    with self.cached_session():
      self.assertAllCloseAccordingToType(
          special.betainc(0.1, b_s, x_s, dtype=np_dt),
          math_ops.betainc(0.1, b_s, x_s).eval(),
          rtol=rtol,
          atol=atol)
      self.assertAllCloseAccordingToType(
          special.betainc(a_s, 0.1, x_s, dtype=np_dt),
          math_ops.betainc(a_s, 0.1, x_s).eval(),
          rtol=rtol,
          atol=atol)
      self.assertAllCloseAccordingToType(
          special.betainc(a_s, b_s, 0.1, dtype=np_dt),
          math_ops.betainc(a_s, b_s, 0.1).eval(),
          rtol=rtol,
          atol=atol)
      self.assertAllCloseAccordingToType(
          special.betainc(0.1, b_s, 0.1, dtype=np_dt),
          math_ops.betainc(0.1, b_s, 0.1).eval(),
          rtol=rtol,
          atol=atol)
      self.assertAllCloseAccordingToType(
          special.betainc(0.1, 0.1, 0.1, dtype=np_dt),
          math_ops.betainc(0.1, 0.1, 0.1).eval(),
          rtol=rtol,
          atol=atol)

      # Statically known, mutually incompatible shapes are rejected when the
      # op is constructed.
      with self.assertRaisesRegex(ValueError, "must be equal"):
        math_ops.betainc(0.5, [0.5], [[0.5]])

      # Shapes that are only known at run time (unshaped placeholders) are
      # rejected when the op executes.
      with self.cached_session():
        with self.assertRaisesOpError("Shapes of .* are inconsistent"):
          a_p = array_ops.placeholder(dtype)
          b_p = array_ops.placeholder(dtype)
          x_p = array_ops.placeholder(dtype)
          math_ops.betainc(a_p, b_p, x_p).eval(
              feed_dict={a_p: 0.5,
                         b_p: [0.5],
                         x_p: [[0.5]]})

  @test_util.run_deprecated_v1
  def testBetaIncFloat(self):
    """Random 10x10 inputs, float32."""
    a_s = np.abs(np.random.randn(10, 10) * 30)  # in (0, infty)
    b_s = np.abs(np.random.randn(10, 10) * 30)  # in (0, infty)
    x_s = np.random.rand(10, 10)  # in (0, 1)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float32)

  @test_util.run_deprecated_v1
  def testBetaIncDouble(self):
    """Random 10x10 inputs, float64."""
    a_s = np.abs(np.random.randn(10, 10) * 30)  # in (0, infty)
    b_s = np.abs(np.random.randn(10, 10) * 30)  # in (0, infty)
    x_s = np.random.rand(10, 10)  # in (0, 1)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float64)

  @test_util.run_deprecated_v1
  def testBetaIncDoubleVeryLargeValues(self):
    """Random float64 inputs with `a` and `b` scaled by 1e15."""
    a_s = np.abs(np.random.randn(10, 10) * 1e15)  # in (0, infty)
    b_s = np.abs(np.random.randn(10, 10) * 1e15)  # in (0, infty)
    x_s = np.random.rand(10, 10)  # in (0, 1)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float64)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/178338235")
  def testBetaIncDoubleVerySmallValues(self):
    """Random float64 inputs with `a` and `b` scaled by 1e-16."""
    a_s = np.abs(np.random.randn(10, 10) * 1e-16)  # in (0, infty)
    b_s = np.abs(np.random.randn(10, 10) * 1e-16)  # in (0, infty)
    x_s = np.random.rand(10, 10)  # in (0, 1)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float64)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/178338235")
  def testBetaIncFloatVerySmallValues(self):
    """Random float32 inputs with `a` and `b` scaled by 1e-8."""
    a_s = np.abs(np.random.randn(10, 10) * 1e-8)  # in (0, infty)
    b_s = np.abs(np.random.randn(10, 10) * 1e-8)  # in (0, infty)
    x_s = np.random.rand(10, 10)  # in (0, 1)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float32)

  @test_util.run_deprecated_v1
  def testBetaIncFpropAndBpropAreNeverNAN(self):
    """Checks that the output and the gradient w.r.t. x contain no NaNs.

    The op is evaluated in float32 on the Cartesian product of
    `np.logspace(-8, 5)` for both `a` and `b` and
    `np.linspace(1e-16, 1 - 1e-16)` for `x`.
    """
    with self.cached_session() as sess:
      space = np.logspace(-8, 5).tolist()
      space_x = np.linspace(1e-16, 1 - 1e-16).tolist()
      # Transpose the (a, b, x) triples into three parallel tuples.
      ga_s, gb_s, gx_s = zip(*itertools.product(space, space, space_x))
      # Test grads are never nan
      ga_s_t = constant_op.constant(ga_s, dtype=dtypes.float32)
      gb_s_t = constant_op.constant(gb_s, dtype=dtypes.float32)
      gx_s_t = constant_op.constant(gx_s, dtype=dtypes.float32)
      tf_gout_t = math_ops.betainc(ga_s_t, gb_s_t, gx_s_t)
      # Only the gradient with respect to x (index 2) is fetched.
      tf_gout, grads_x = sess.run(
          [tf_gout_t,
           gradients_impl.gradients(tf_gout_t, [ga_s_t, gb_s_t, gx_s_t])[2]])

      # Equivalent to `assertAllFalse` (if it existed).
      self.assertAllEqual(
          np.zeros_like(grads_x).astype(np.bool_), np.isnan(tf_gout))
      self.assertAllEqual(
          np.zeros_like(grads_x).astype(np.bool_), np.isnan(grads_x))

  @test_util.run_deprecated_v1
  def testBetaIncGrads(self):
    """Checks the gradient w.r.t. x with the gradient checker.

    Runs once with `x` the same shape as `a` and `b`, and once with a scalar
    `x` that is broadcast against them.
    """
    err_tolerance = 1e-3
    with self.cached_session():
      # Test gradient
      ga_s = np.abs(np.random.randn(2, 2) * 30)  # in (0, infty)
      gb_s = np.abs(np.random.randn(2, 2) * 30)  # in (0, infty)
      gx_s = np.random.rand(2, 2)  # in (0, 1)
      tf_ga_s = constant_op.constant(ga_s, dtype=dtypes.float64)
      tf_gb_s = constant_op.constant(gb_s, dtype=dtypes.float64)
      tf_gx_s = constant_op.constant(gx_s, dtype=dtypes.float64)
      tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s)
      err = gradient_checker.compute_gradient_error(
          [tf_gx_s], [gx_s.shape], tf_gout_t, gx_s.shape)
      tf_logging.info("betainc gradient err = %g ", err)
      self.assertLess(err, err_tolerance)

      # Test broadcast gradient
      gx_s = np.random.rand()  # in (0, 1)
      tf_gx_s = constant_op.constant(gx_s, dtype=dtypes.float64)
      tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s)
      # x is a scalar here, so the output takes the shape of a and b.
      err = gradient_checker.compute_gradient_error(
          [tf_gx_s], [()], tf_gout_t, ga_s.shape)
      tf_logging.info("betainc gradient err = %g ", err)
      self.assertLess(err, err_tolerance)


if __name__ == "__main__":
  test.main()