• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
"""Functional tests for the betainc (regularized incomplete beta) op."""
16
17import itertools
18
19import numpy as np
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import gradient_checker
26from tensorflow.python.ops import gradients_impl
27from tensorflow.python.ops import math_ops
28from tensorflow.python.platform import test
29from tensorflow.python.platform import tf_logging
30
31
class BetaincTest(test.TestCase):
  """Tests for the `math_ops.betainc` op and its gradient w.r.t. `x`."""

  def _randomInputs(self, scale, shape=(10, 10)):
    """Draws random `(a, b, x)` arguments for `betainc`.

    Args:
      scale: Multiplier applied to the standard-normal samples behind `a` and
        `b`; it controls the magnitude of those two parameters.
      shape: Shape of each returned array.

    Returns:
      A tuple `(a_s, b_s, x_s)` of NumPy arrays. `a_s` and `b_s` are absolute
      values of scaled normal samples; `x_s` is uniform on [0, 1).
    """
    # Draw order (a, b, x) is kept fixed so a seeded RNG yields the same data.
    a_s = np.abs(np.random.randn(*shape) * scale)
    b_s = np.abs(np.random.randn(*shape) * scale)
    x_s = np.random.rand(*shape)
    return a_s, b_s, x_s

  def _testBetaInc(self, a_s, b_s, x_s, dtype):
    """Compares `math_ops.betainc` with `scipy.special.betainc`.

    Covers the supplied values, a grid that includes out-of-range arguments,
    broadcasting against scalars, and the errors raised for inconsistent
    shapes. If SciPy cannot be imported, a warning is logged and nothing is
    checked.

    Args:
      a_s: NumPy array of `a` parameters.
      b_s: NumPy array of `b` parameters.
      x_s: NumPy array of evaluation points.
      dtype: TensorFlow dtype under test; inputs are cast to its NumPy
        equivalent.
    """
    # Only the import is guarded, so an ImportError raised while exercising
    # TensorFlow surfaces as a test failure instead of being logged and lost.
    try:
      from scipy import special  # pylint: disable=g-import-not-at-top
    except ImportError as e:
      tf_logging.warn("Cannot test special functions: %s", e)
      return

    np_dt = dtype.as_numpy_dtype

    # Test random values
    a_s = a_s.astype(np_dt)
    b_s = b_s.astype(np_dt)
    x_s = x_s.astype(np_dt)
    tf_a_s = constant_op.constant(a_s, dtype=dtype)
    tf_b_s = constant_op.constant(b_s, dtype=dtype)
    tf_x_s = constant_op.constant(x_s, dtype=dtype)
    tf_out_t = math_ops.betainc(tf_a_s, tf_b_s, tf_x_s)
    with self.cached_session():
      tf_out = self.evaluate(tf_out_t)
    scipy_out = special.betainc(a_s, b_s, x_s, dtype=np_dt)

    # the scipy version of betainc uses a double-only implementation.
    # TODO(ebrevdo): identify reasons for (sometime) precision loss
    # with doubles
    rtol = 1e-4
    atol = 1e-5
    self.assertAllCloseAccordingToType(
        scipy_out, tf_out, rtol=rtol, atol=atol)

    # Test out-of-range values (most should return nan output)
    combinations = list(itertools.product([-1, 0, 0.5, 1.0, 1.5], repeat=3))
    a_comb, b_comb, x_comb = np.asarray(list(zip(*combinations)), dtype=np_dt)
    with self.cached_session():
      tf_comb = self.evaluate(math_ops.betainc(a_comb, b_comb, x_comb))
    scipy_comb = special.betainc(a_comb, b_comb, x_comb, dtype=np_dt)
    self.assertAllCloseAccordingToType(
        scipy_comb, tf_comb, rtol=rtol, atol=atol)

    # Test broadcasting between scalars and other shapes
    broadcast_cases = (
        (0.1, b_s, x_s),
        (a_s, 0.1, x_s),
        (a_s, b_s, 0.1),
        (0.1, b_s, 0.1),
        (0.1, 0.1, 0.1),
    )
    with self.cached_session():
      for a, b, x in broadcast_cases:
        self.assertAllCloseAccordingToType(
            special.betainc(a, b, x, dtype=np_dt),
            self.evaluate(math_ops.betainc(a, b, x)),
            rtol=rtol,
            atol=atol)

    # Statically known, incompatible shapes are rejected at graph construction.
    with self.assertRaisesRegex(ValueError, "must be equal"):
      math_ops.betainc(0.5, [0.5], [[0.5]])

    # With placeholders the mismatch is only detectable when the op runs, so
    # just the evaluation sits inside the assertion context.
    with self.cached_session():
      a_p = array_ops.placeholder(dtype)
      b_p = array_ops.placeholder(dtype)
      x_p = array_ops.placeholder(dtype)
      out_p = math_ops.betainc(a_p, b_p, x_p)
      with self.assertRaisesOpError("Shapes of .* are inconsistent"):
        out_p.eval(feed_dict={a_p: 0.5, b_p: [0.5], x_p: [[0.5]]})

  @test_util.run_deprecated_v1
  def testBetaIncFloat(self):
    """Checks float32 results for moderately sized `a` and `b`."""
    a_s, b_s, x_s = self._randomInputs(30)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float32)

  @test_util.run_deprecated_v1
  def testBetaIncDouble(self):
    """Checks float64 results for moderately sized `a` and `b`."""
    a_s, b_s, x_s = self._randomInputs(30)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float64)

  @test_util.run_deprecated_v1
  def testBetaIncDoubleVeryLargeValues(self):
    """Checks float64 results when `a` and `b` are scaled by 1e15."""
    a_s, b_s, x_s = self._randomInputs(1e15)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float64)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/178338235")
  def testBetaIncDoubleVerySmallValues(self):
    """Checks float64 results when `a` and `b` are scaled by 1e-16."""
    a_s, b_s, x_s = self._randomInputs(1e-16)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float64)

  @test_util.run_deprecated_v1
  @test_util.disable_xla("b/178338235")
  def testBetaIncFloatVerySmallValues(self):
    """Checks float32 results when `a` and `b` are scaled by 1e-8."""
    a_s, b_s, x_s = self._randomInputs(1e-8)
    self._testBetaInc(a_s, b_s, x_s, dtypes.float32)

  @test_util.run_deprecated_v1
  def testBetaIncFpropAndBpropAreNeverNAN(self):
    """Checks that outputs and x-gradients contain no NaN over a wide grid."""
    with self.cached_session() as sess:
      space = np.logspace(-8, 5).tolist()
      space_x = np.linspace(1e-16, 1 - 1e-16).tolist()
      ga_s, gb_s, gx_s = zip(*itertools.product(space, space, space_x))
      # Test grads are never nan
      ga_s_t = constant_op.constant(ga_s, dtype=dtypes.float32)
      gb_s_t = constant_op.constant(gb_s, dtype=dtypes.float32)
      gx_s_t = constant_op.constant(gx_s, dtype=dtypes.float32)
      tf_gout_t = math_ops.betainc(ga_s_t, gb_s_t, gx_s_t)
      # Index 2 selects the gradient with respect to x.
      grads_x_t = gradients_impl.gradients(
          tf_gout_t, [ga_s_t, gb_s_t, gx_s_t])[2]
      tf_gout, grads_x = sess.run([tf_gout_t, grads_x_t])

      # Equivalent to `assertAllFalse` (if it existed). Each expected mask is
      # built from the array it is compared against.
      self.assertAllEqual(
          np.zeros_like(tf_gout, dtype=np.bool_), np.isnan(tf_gout))
      self.assertAllEqual(
          np.zeros_like(grads_x, dtype=np.bool_), np.isnan(grads_x))

  @test_util.run_deprecated_v1
  def testBetaIncGrads(self):
    """Checks the x-gradient numerically, with and without broadcasting."""
    err_tolerance = 1e-3
    with self.cached_session():
      # Test gradient
      ga_s, gb_s, gx_s = self._randomInputs(30, shape=(2, 2))
      tf_ga_s = constant_op.constant(ga_s, dtype=dtypes.float64)
      tf_gb_s = constant_op.constant(gb_s, dtype=dtypes.float64)
      tf_gx_s = constant_op.constant(gx_s, dtype=dtypes.float64)
      tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s)
      err = gradient_checker.compute_gradient_error(
          [tf_gx_s], [gx_s.shape], tf_gout_t, gx_s.shape)
      tf_logging.info("betainc gradient err = %g ", err)
      self.assertLess(err, err_tolerance)

      # Test broadcast gradient: x is a scalar, so the output takes the shape
      # of a and b.
      gx_s = np.random.rand()  # in [0, 1)
      tf_gx_s = constant_op.constant(gx_s, dtype=dtypes.float64)
      tf_gout_t = math_ops.betainc(tf_ga_s, tf_gb_s, tf_gx_s)
      err = gradient_checker.compute_gradient_error(
          [tf_gx_s], [()], tf_gout_t, ga_s.shape)
      tf_logging.info("betainc gradient err = %g ", err)
      self.assertLess(err, err_tolerance)
195
196
# Invoke the TensorFlow test entry point when this file is run as a script.
if __name__ == "__main__":
  test.main()
199