# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

16"""Tests for special math operations."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import os
23
24from absl import flags
25from absl.testing import parameterized
26
27import numpy as np
28import scipy.special as sps
29import six
30
31from tensorflow.compiler.tests import xla_test
32from tensorflow.python.eager import def_function
33from tensorflow.python.framework import constant_op
34from tensorflow.python.ops import gen_math_ops
35from tensorflow.python.ops import gen_random_ops
36from tensorflow.python.ops import gradient_checker_v2
37from tensorflow.python.ops import math_ops
38from tensorflow.python.platform import test
39
flags.DEFINE_bool('vary_seed', False,
                  ('Whether to vary the PRNG seed unpredictably.  '
                   'With --runs_per_test=N, produces N iid runs.'))

NUM_SAMPLES = int(1e3)


@def_function.function(jit_compile=True)
def _igamma(a, x):
  return math_ops.igamma(a, x)


@def_function.function(jit_compile=True)
def _igammac(a, x):
  return math_ops.igammac(a, x)


@def_function.function(jit_compile=True)
def _polygamma(n, x):
  return math_ops.polygamma(n, x)


@def_function.function(jit_compile=True)
def _zeta(a, q):
  return math_ops.zeta(a, q)


# This is -(df/da) / (df/dx), where f = igamma: the implicit
# reparameterization gradient dx/da of a Gamma(a, 1) sample x holding the
# CDF value f(a, x) fixed. df/dx is the Gamma(a, 1) density.
def implicit_reparameterization_grad(a, x):
  log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x
  prob = math_ops.exp(log_prob)
  return -gen_math_ops.igamma_grad_a(a, x) / prob
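

# A minimal numpy cross-check of the same identity (an illustrative sketch,
# not part of the original tests; this helper is unused): approximate df/da
# by central differences and take df/dx from the Gamma(a, 1) density.
def _np_implicit_reparameterization_grad(a, x, eps=1e-6):
  df_da = (sps.gammainc(a + eps, x) - sps.gammainc(a - eps, x)) / (2. * eps)
  df_dx = np.exp((a - 1.) * np.log(x) - sps.gammaln(a) - x)
  return -df_da / df_dx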


@def_function.function(jit_compile=True)
def _log1p(x):
  return math_ops.log1p(x)


class Log1pTest(xla_test.XLATestCase, parameterized.TestCase):

  def setUp(self):
    if flags.FLAGS.vary_seed:
      entropy = os.urandom(64)
      if six.PY2:
        answer = int(entropy.encode('hex'), 16)
      else:
        answer = int.from_bytes(entropy, 'big')
      # Reduce the entropy to an int in the range np.random.seed accepts.
      np.random.seed(answer % (2**32 - 1))
    super(Log1pTest, self).setUp()

  def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
    if self.device not in ['TPU']:
      return rtol, atol

    if dtype == np.float32:
      return 4e-4, 0.
    return 1e-10, 0.

  def _test_range(self, low, high, dtype, rtol, atol, is_negative=False):
    # Test log1p(x) for x = +/-exp(u), with u drawn uniformly from
    # [low, high].
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    x = np.exp(np.random.uniform(
        low=low, high=high, size=[NUM_SAMPLES])).astype(dtype)
    if is_negative:
      x = -x
    expected_values = np.log1p(x)
    with self.session() as sess:
      with self.test_scope():
        actual = _log1p(x)
      actual = sess.run(actual)
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

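  # Why log1p rather than log(1. + x) (an illustrative note, not from the
  # original file): for |x| << 1, the sum 1. + x rounds away the low-order
  # digits of x that log1p preserves. For example, in float64,
  # np.log(1. + 1e-17) evaluates to exactly 0.0, while np.log1p(1e-17) is
  # approximately 1e-17.
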
  @parameterized.parameters((np.float32, 1e-7, 0.),
                            (np.float64, 1e-15, 0.))
  def testSmallX(self, dtype, rtol, atol):
    self._test_range(-40., -20., dtype, rtol, atol, is_negative=False)
    self._test_range(-40., -20., dtype, rtol, atol, is_negative=True)

  @parameterized.parameters((np.float32, 2e-7, 0.),
                            (np.float64, 1e-15, 0.))
  def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol):
    self._test_range(-20., -10., dtype, rtol, atol, is_negative=False)
    self._test_range(-20., -10., dtype, rtol, atol, is_negative=True)

  @parameterized.parameters((np.float32, 2e-7, 0.),
                            (np.float64, 1e-15, 0.))
  def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol):
    self._test_range(-10., -5., dtype, rtol, atol, is_negative=False)
    self._test_range(-10., -5., dtype, rtol, atol, is_negative=True)

  @parameterized.parameters((np.float32, 2e-7, 0.),
                            (np.float64, 1e-15, 0.))
  def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol):
    self._test_range(-5., -1., dtype, rtol, atol, is_negative=False)
    self._test_range(-5., -1., dtype, rtol, atol, is_negative=True)

  @parameterized.parameters((np.float32, 4e-7, 0.),
                            (np.float64, 3e-14, 0.))
  def testXGreaterThanOneTenth(self, dtype, rtol, atol):
    self._test_range(-1., 0., dtype, rtol, atol, is_negative=False)
    self._test_range(-1., 0., dtype, rtol, atol, is_negative=True)

  @parameterized.parameters((np.float32, 2e-7, 0.),
                            (np.float64, 2e-15, 0.))
  def testXGreaterThanOne(self, dtype, rtol, atol):
    self._test_range(0., 3., dtype, rtol, atol, is_negative=False)


class ZetaTest(xla_test.XLATestCase, parameterized.TestCase):

  def setUp(self):
    if flags.FLAGS.vary_seed:
      entropy = os.urandom(64)
      if six.PY2:
        answer = int(entropy.encode('hex'), 16)
      else:
        answer = int.from_bytes(entropy, 'big')
      np.random.seed(answer % (2**32 - 1))
    super(ZetaTest, self).setUp()

  def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
    if self.device not in ['TPU']:
      return rtol, atol

    if dtype == np.float32:
      return 2e-2, 1e-7
    return 2e-4, 1e-20

  def testBadValues(self):
    q = np.random.uniform(low=0.3, high=20., size=[10])
    with self.session() as sess:
      with self.test_scope():
        y = _zeta(np.float64(1.), q)
      actual = sess.run(y)
    # At x == 1 the series reduces to the (divergent) harmonic series.
    self.assertTrue(np.all(np.isinf(actual)))

    with self.session() as sess:
      with self.test_scope():
        y = _zeta(np.float64(0.1), q)
      actual = sess.run(y)
    # For x < 1 the series diverges, so zeta is undefined.
    self.assertTrue(np.all(np.isnan(actual)))

    with self.session() as sess:
      with self.test_scope():
        y = _zeta([1.1, 1.2, 2.1, 2.2, 3.1], [-2.0, -1.1, -1.0, -0.5, -0.1])
      actual = sess.run(y)
    # For q <= 0, x must be an integer.
    self.assertTrue(np.all(np.isnan(actual)))

    with self.session() as sess:
      with self.test_scope():
        y = _zeta([2.0, 4.0, 6.0], [0.0, -1.0, -2.0])
      actual = sess.run(y)
    # For non-positive integer q, zeta has poles with a defined limit of +inf
    # where x is an even integer.
    self.assertTrue(np.all(np.isinf(actual)))

    with self.session() as sess:
      with self.test_scope():
        y = _zeta([3.0, 5.0, 7.0], [0.0, -1.0, -2.0])
      actual = sess.run(y)
    # For non-positive integer q, zeta has poles with an undefined limit where
    # x is an odd integer.
    self.assertTrue(np.all(np.isnan(actual)))

    with self.session() as sess:
      with self.test_scope():
        y = _zeta([1.1, 2.2, 3.3], [-1.1, -1.0, 0.0])
      actual = sess.run(y)
    # For non-positive q, zeta is not defined unless x is an integer.
    self.assertTrue(np.all(np.isnan(actual)))

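  # For reference (a note added for readability, not in the original file):
  # the Hurwitz zeta function under test is
  #   zeta(x, q) = sum_{n=0}^{inf} (q + n)**(-x),
  # which converges for x > 1 and q > 0; the probes above exercise its
  # boundary and invalid regions.
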
  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testLargeXSmallQ(self, dtype, rtol, atol):
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
      # TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns
      # infs.
      self.skipTest(
          'Skipping test because some F64 operations are numerically '
          'unstable on TPU.')

    x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
    q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.zeta(x, q)
    with self.session() as sess:
      with self.test_scope():
        y = _zeta(x, q)
      actual = sess.run(y)

    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testSmallValues(self, dtype, rtol, atol):
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test q values near zero.
    x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype)
    q = np.random.uniform(
        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.zeta(x, q)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_zeta(x, q))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testMediumValues(self, dtype, rtol, atol):
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype)
    q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.zeta(x, q)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_zeta(x, q))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
  def testLargeValues(self, dtype, rtol, atol):
    x = np.random.uniform(
        low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype)
    q = np.random.uniform(
        low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.zeta(x, q)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_zeta(x, q))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)


class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase):

  def setUp(self):
    if flags.FLAGS.vary_seed:
      entropy = os.urandom(64)
      if six.PY2:
        answer = int(entropy.encode('hex'), 16)
      else:
        answer = int.from_bytes(entropy, 'big')
      np.random.seed(answer % (2**32 - 1))
    super(PolygammaTest, self).setUp()

  def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
    if self.device not in ['TPU']:
      return rtol, atol

    if dtype == np.float32:
      return 2e-2, 1e-7
    return 2e-4, 1e-20

  def testBadValues(self):
    x = np.random.uniform(low=0.3, high=20., size=[10])
    with self.session() as sess:
      with self.test_scope():
        y = _polygamma(np.float64(-1.), x)
      actual = sess.run(y)
    # Not defined for negative n.
    self.assertTrue(np.all(np.isnan(actual)))

    with self.session() as sess:
      with self.test_scope():
        y = _polygamma(np.float64(0.1), x)
      actual = sess.run(y)
    # Not defined for non-integer n.
    self.assertTrue(np.all(np.isnan(actual)))

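  # For reference (a note added for readability, not in the original file):
  # polygamma(n, x) is the (n + 1)-th derivative of log(gamma(x)) and is
  # defined only for integer n >= 0; n == 0 recovers the digamma function,
  # which testRecoverDigamma checks against scipy below.
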
  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testRecoverDigamma(self, dtype, rtol, atol):
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
      self.skipTest(
          'Skipping test because some F64 operations are '
          'numerically unstable on TPU.')

    x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype)
    expected_values = sps.digamma(x)
    with self.session() as sess:
      with self.test_scope():
        y = _polygamma(dtype(0.), x)
      actual = sess.run(y)

    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testSmallN(self, dtype, rtol, atol):
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test x values near zero.
    n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype)
    x = np.random.uniform(
        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.polygamma(n, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_polygamma(n, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testMediumLargeN(self, dtype, rtol, atol):
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype)
    x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.polygamma(n, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_polygamma(n, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)


class IgammaTest(xla_test.XLATestCase, parameterized.TestCase):

  def setUp(self):
    if flags.FLAGS.vary_seed:
      entropy = os.urandom(64)
      if six.PY2:
        answer = int(entropy.encode('hex'), 16)
      else:
        answer = int.from_bytes(entropy, 'big')
      np.random.seed(answer % (2**32 - 1))
    super(IgammaTest, self).setUp()

  # Skip Float64 test on TPU due to missing ops.
  def maybe_skip_test(self, dtype):
    if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
      self.skipTest(
          'Skipping test because some F64 operations not supported on TPU.')

  def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
    if self.device not in ['TPU']:
      return rtol, atol

    if dtype == np.float32:
      return 2e-2, 1e-7
    return 2e-4, 1e-20

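  # For reference (a note added for readability, not in the original file):
  # igamma(a, x) is the regularized lower incomplete gamma function
  # P(a, x) = gamma(a, x) / Gamma(a), matching the scipy oracle
  # sps.gammainc(a, x) used below.
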
  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testLargeXSmallA(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test large x with small a.
    x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammainc(a, x)
    with self.session() as sess:
      with self.test_scope():
        y = _igamma(a, x)
      actual = sess.run(y)
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testSmallValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test values near zero.
    x = np.random.uniform(
        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(
        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammainc(a, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_igamma(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testMediumValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test medium-range values.
    x = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammainc(a, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_igamma(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
  def testLargeValues(self, dtype, rtol, atol):
    if self.device == 'TPU':
      # TODO(b/154908275): Remove this once fixed for large a, x.
      self.skipTest('Skipping test since numerically unstable on TPU.')
    # Test large values.
    x = np.random.uniform(
        low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(
        low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammainc(a, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_igamma(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  # We don't check small values because the numerical gradients become quite
  # large.
  @parameterized.parameters((np.float32, 0.09), (np.float64, 1e-7))
  def testGradMediumValues(self, dtype, tolerance):
    self.maybe_skip_test(dtype)
    with self.session():
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(low=1., high=100.,
                              size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(low=1., high=100.,
                              size=[NUM_SAMPLES]).astype(dtype))

        f = lambda b: _igamma(b, x)
        max_error = gradient_checker_v2.max_error(
            *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-3))
    self.assertLessEqual(max_error, tolerance)

  @parameterized.parameters((np.float32, 0.5), (np.float64, 1e-7))
  def testGradLargeValues(self, dtype, tolerance):
    self.maybe_skip_test(dtype)
    with self.session():
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(low=100., high=int(1e4),
                              size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(low=100., high=int(1e4),
                              size=[NUM_SAMPLES]).astype(dtype))

        f = lambda b: _igamma(b, x)
        max_error = gradient_checker_v2.max_error(
            *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-2))
    self.assertLessEqual(max_error, tolerance)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testRandomGammaGradSmallValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test values near zero.

    with self.session() as sess:
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(
                low=np.finfo(dtype).tiny, high=1.,
                size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(
                low=np.finfo(dtype).tiny, high=1.,
                size=[NUM_SAMPLES]).astype(dtype))
        gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x)
        actual_grad = implicit_reparameterization_grad(a, x)
        gamma_sample_grad, actual_grad = sess.run(
            [gamma_sample_grad, actual_grad])
        # Filter out NaN and Inf entries: the ratio computed in
        # implicit_reparameterization_grad easily produces them when its
        # numerator and denominator underflow to zero.
        gamma_sample_grad = gamma_sample_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
        actual_grad = actual_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
    self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testRandomGammaGradMediumValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)

    with self.session() as sess:
      with self.test_scope():
        x = constant_op.constant(
            np.random.uniform(low=1., high=10.,
                              size=[NUM_SAMPLES]).astype(dtype))
        a = constant_op.constant(
            np.random.uniform(low=1., high=10.,
                              size=[NUM_SAMPLES]).astype(dtype))
        gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x)
        actual_grad = implicit_reparameterization_grad(a, x)
        gamma_sample_grad, actual_grad = sess.run(
            [gamma_sample_grad, actual_grad])
        # Filter out NaN and Inf entries, as in
        # testRandomGammaGradSmallValues above.
        gamma_sample_grad = gamma_sample_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
        actual_grad = actual_grad[
            ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))]
    self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol)


class IgammacTest(xla_test.XLATestCase, parameterized.TestCase):

  def setUp(self):
    if flags.FLAGS.vary_seed:
      entropy = os.urandom(64)
      if six.PY2:
        answer = int(entropy.encode('hex'), 16)
      else:
        answer = int.from_bytes(entropy, 'big')
      np.random.seed(answer % (2**32 - 1))
    super(IgammacTest, self).setUp()

  # Skip Float64 test on TPU due to missing ops.
  def maybe_skip_test(self, dtype):
    if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64:
      # TODO(b/154908275): Remove this once fixed for large a, x.
      self.skipTest(
          'Skipping test because some F64 operations not supported on TPU.')

  def adjust_tolerance_for_tpu(self, dtype, rtol, atol):
    if self.device not in ['TPU']:
      return rtol, atol

    if dtype == np.float32:
      return 2e-2, 1e-7
    return 2e-4, 1e-20

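  # For reference (a note added for readability, not in the original file):
  # igammac(a, x) = 1 - igamma(a, x) is the regularized upper incomplete
  # gamma function Q(a, x), matching the scipy oracle sps.gammaincc(a, x)
  # used below.
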
  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testLargeXSmallA(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test large x with small a.
    x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammaincc(a, x)
    with self.session() as sess:
      with self.test_scope():
        y = _igammac(a, x)
      actual = sess.run(y)
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testSmallValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test values near zero.
    x = np.random.uniform(
        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(
        low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammaincc(a, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_igammac(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 1e-2, 1e-11),
                            (np.float64, 1e-4, 1e-30))
  def testMediumValues(self, dtype, rtol, atol):
    self.maybe_skip_test(dtype)
    rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol)
    # Test medium-range values.
    x = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammaincc(a, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_igammac(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)

  @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30))
  def testLargeValues(self, dtype, rtol, atol):
    if self.device == 'TPU':
      self.skipTest('Skipping test since numerically unstable on TPU.')
    # Test large values.
    x = np.random.uniform(
        low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype)
    a = np.random.uniform(
        low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype)

    expected_values = sps.gammaincc(a, x)
    with self.session() as sess:
      with self.test_scope():
        actual = sess.run(_igammac(a, x))
    self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol)


if __name__ == '__main__':
  # Disable XLA CPU fast math so results stay comparable to the scipy
  # reference values.
  os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false'
  test.main()