1# Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16"""Tests for special math operations.""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import os 23 24from absl import flags 25from absl.testing import parameterized 26 27import numpy as np 28import scipy.special as sps 29import six 30 31from tensorflow.compiler.tests import xla_test 32from tensorflow.python.eager import def_function 33from tensorflow.python.framework import constant_op 34from tensorflow.python.ops import gen_math_ops 35from tensorflow.python.ops import gen_random_ops 36from tensorflow.python.ops import gradient_checker_v2 37from tensorflow.python.ops import math_ops 38from tensorflow.python.platform import test 39 40flags.DEFINE_bool('vary_seed', False, 41 ('Whether to vary the PRNG seed unpredictably. ' 42 'With --runs_per_test=N, produces N iid runs.')) 43 44NUM_SAMPLES = int(1e3) 45 46 47@def_function.function(jit_compile=True) 48def _igamma(a, x): 49 return math_ops.igamma(a, x) 50 51 52@def_function.function(jit_compile=True) 53def _igammac(a, x): 54 return math_ops.igammac(a, x) 55 56 57@def_function.function(jit_compile=True) 58def _polygamma(n, x): 59 return math_ops.polygamma(n, x) 60 61 62@def_function.function(jit_compile=True) 63def _zeta(a, q): 64 return math_ops.zeta(a, q) 65 66 67# This is df/da / df/dx, where f = igamma. 68def implicit_reparameterization_grad(a, x): 69 log_prob = math_ops.xlogy(a - 1., x) - math_ops.lgamma(a) - x 70 prob = math_ops.exp(log_prob) 71 return -gen_math_ops.igamma_grad_a(a, x) / prob 72 73 74@def_function.function(jit_compile=True) 75def _log1p(x): 76 return math_ops.log1p(x) 77 78 79class Log1pTest(xla_test.XLATestCase, parameterized.TestCase): 80 81 def setUp(self): 82 if flags.FLAGS.vary_seed: 83 entropy = os.urandom(64) 84 if six.PY2: 85 answer = int(entropy.encode('hex'), 16) 86 else: 87 answer = int.from_bytes(entropy, 'big') 88 np.random.seed(answer % (2**32 - 1)) 89 super(Log1pTest, self).setUp() 90 91 def adjust_tolerance_for_tpu(self, dtype, rtol, atol): 92 if self.device not in ['TPU']: 93 return rtol, atol 94 95 if dtype == np.float32: 96 return 4e-4, 0. 97 return 1e-10, 0. 98 99 def _test_range(self, low, high, dtype, rtol, atol, is_negative=False): 100 # Test values near zero. 101 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 102 x = np.exp(np.random.uniform( 103 low=low, high=high, size=[NUM_SAMPLES])).astype(dtype) 104 if is_negative: 105 x = -x 106 expected_values = np.log1p(x) 107 with self.session() as sess: 108 with self.test_scope(): 109 actual = _log1p(x) 110 actual = sess.run(actual) 111 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 112 113 @parameterized.parameters((np.float32, 1e-7, 0.), 114 (np.float64, 1e-15, 0.)) 115 def testSmallX(self, dtype, rtol, atol): 116 self._test_range(-40., -20., dtype, rtol, atol, is_negative=False) 117 self._test_range(-40., -20., dtype, rtol, atol, is_negative=True) 118 119 @parameterized.parameters((np.float32, 2e-7, 0.), 120 (np.float64, 1e-15, 0.)) 121 def testGreaterThanNegativeTwentyExponent(self, dtype, rtol, atol): 122 self._test_range(-20., -10., dtype, rtol, atol, is_negative=False) 123 self._test_range(-20., -10., dtype, rtol, atol, is_negative=True) 124 125 @parameterized.parameters((np.float32, 2e-7, 0.), 126 (np.float64, 1e-15, 0.)) 127 def testGreaterThanNegativeTenExponent(self, dtype, rtol, atol): 128 self._test_range(-10., -5., dtype, rtol, atol, is_negative=False) 129 self._test_range(-10., -5., dtype, rtol, atol, is_negative=True) 130 131 @parameterized.parameters((np.float32, 2e-7, 0.), 132 (np.float64, 1e-15, 0.)) 133 def testGreaterThanNegativeFiveExponent(self, dtype, rtol, atol): 134 self._test_range(-5., -1., dtype, rtol, atol, is_negative=False) 135 self._test_range(-5., -1., dtype, rtol, atol, is_negative=True) 136 137 @parameterized.parameters((np.float32, 4e-7, 0.), 138 (np.float64, 3e-14, 0.)) 139 def testXGreaterThanOneTenth(self, dtype, rtol, atol): 140 self._test_range(-1., 0., dtype, rtol, atol, is_negative=False) 141 self._test_range(-1., 0., dtype, rtol, atol, is_negative=True) 142 143 @parameterized.parameters((np.float32, 2e-7, 0.), 144 (np.float64, 2e-15, 0.)) 145 def testXGreaterThanOne(self, dtype, rtol, atol): 146 self._test_range(0., 3., dtype, rtol, atol, is_negative=False) 147 148 149class ZetaTest(xla_test.XLATestCase, parameterized.TestCase): 150 151 def setUp(self): 152 if flags.FLAGS.vary_seed: 153 entropy = os.urandom(64) 154 if six.PY2: 155 answer = int(entropy.encode('hex'), 16) 156 else: 157 answer = int.from_bytes(entropy, 'big') 158 np.random.seed(answer % (2**32 - 1)) 159 super(ZetaTest, self).setUp() 160 161 def adjust_tolerance_for_tpu(self, dtype, rtol, atol): 162 if self.device not in ['TPU']: 163 return rtol, atol 164 165 if dtype == np.float32: 166 return 2e-2, 1e-7 167 return 2e-4, 1e-20 168 169 def testBadValues(self): 170 q = np.random.uniform(low=0.3, high=20., size=[10]) 171 with self.session() as sess: 172 with self.test_scope(): 173 y = _zeta(np.float64(1.), q) 174 actual = sess.run(y) 175 # When x == 1, this is the Harmonic series. 176 self.assertTrue(np.all(np.isinf(actual))) 177 178 with self.session() as sess: 179 with self.test_scope(): 180 y = _zeta(np.float64(0.1), q) 181 actual = sess.run(y) 182 # When x < 1, this is undefined. 183 self.assertTrue(np.all(np.isnan(actual))) 184 185 with self.session() as sess: 186 with self.test_scope(): 187 y = _zeta([1.1, 1.2, 2.1, 2.2, 3.1], [-2.0, -1.1, -1.0, -0.5, -0.1]) 188 actual = sess.run(y) 189 # For q <= 0, x must be an integer. 190 self.assertTrue(np.all(np.isnan(actual))) 191 192 with self.session() as sess: 193 with self.test_scope(): 194 y = _zeta([2.0, 4.0, 6.0], [0.0, -1.0, -2.0]) 195 actual = sess.run(y) 196 # For integer q <= 0, zeta has poles with a defined limit of +inf where x is 197 # an even integer. 198 self.assertTrue(np.all(np.isinf(actual))) 199 200 with self.session() as sess: 201 with self.test_scope(): 202 y = _zeta([3.0, 5.0, 7.0], [0.0, -1.0, -2.0]) 203 actual = sess.run(y) 204 # For non-positive integer q, zeta has poles with an undefined limit where x 205 # is an odd integer. 206 self.assertTrue(np.all(np.isnan(actual))) 207 208 with self.session() as sess: 209 with self.test_scope(): 210 y = _zeta([1.1, 2.2, 3.3], [-1.1, -1.0, 0.0]) 211 actual = sess.run(y) 212 # For non-positive q, zeta is not defined if x is not an integer. 213 self.assertTrue(np.all(np.isnan(actual))) 214 215 @parameterized.parameters((np.float32, 1e-2, 1e-11), 216 (np.float64, 1e-4, 1e-30)) 217 def testLargeXSmallQ(self, dtype, rtol, atol): 218 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 219 if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: 220 # TODO(b/165739664): Figure out why on TPU F64 Zeta sometimes returns 221 # infs. 222 self.skipTest( 223 'Skipping test because some F64 operations are numerically ' 224 'unstable on TPU.') 225 226 x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype) 227 q = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype) 228 229 expected_values = sps.zeta(x, q) 230 with self.session() as sess: 231 with self.test_scope(): 232 y = _zeta(x, q) 233 actual = sess.run(y) 234 235 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 236 237 @parameterized.parameters((np.float32, 1e-2, 1e-11), 238 (np.float64, 1e-4, 1e-30)) 239 def testSmallValues(self, dtype, rtol, atol): 240 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 241 # Test values near zero. 242 x = np.random.uniform(low=1.1, high=10., size=[NUM_SAMPLES]).astype(dtype) 243 q = np.random.uniform( 244 low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) 245 246 expected_values = sps.zeta(x, q) 247 with self.session() as sess: 248 with self.test_scope(): 249 actual = sess.run(_zeta(x, q)) 250 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 251 252 @parameterized.parameters((np.float32, 1e-2, 1e-11), 253 (np.float64, 1e-4, 1e-30)) 254 def testMediumValues(self, dtype, rtol, atol): 255 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 256 x = np.random.uniform(low=1.1, high=100., size=[NUM_SAMPLES]).astype(dtype) 257 q = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype) 258 259 expected_values = sps.zeta(x, q) 260 with self.session() as sess: 261 with self.test_scope(): 262 actual = sess.run(_zeta(x, q)) 263 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 264 265 @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30)) 266 def testLargeValues(self, dtype, rtol, atol): 267 x = np.random.uniform( 268 low=100., high=int(1e3), size=[NUM_SAMPLES]).astype(dtype) 269 q = np.random.uniform( 270 low=1., high=int(1e1), size=[NUM_SAMPLES]).astype(dtype) 271 272 expected_values = sps.zeta(x, q) 273 with self.session() as sess: 274 with self.test_scope(): 275 actual = sess.run(_zeta(x, q)) 276 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 277 278 279class PolygammaTest(xla_test.XLATestCase, parameterized.TestCase): 280 281 def setUp(self): 282 if flags.FLAGS.vary_seed: 283 entropy = os.urandom(64) 284 if six.PY2: 285 answer = int(entropy.encode('hex'), 16) 286 else: 287 answer = int.from_bytes(entropy, 'big') 288 np.random.seed(answer % (2**32 - 1)) 289 super(PolygammaTest, self).setUp() 290 291 def adjust_tolerance_for_tpu(self, dtype, rtol, atol): 292 if self.device not in ['TPU']: 293 return rtol, atol 294 295 if dtype == np.float32: 296 return 2e-2, 1e-7 297 return 2e-4, 1e-20 298 299 def testBadValues(self): 300 x = np.random.uniform(low=0.3, high=20., size=[10]) 301 with self.session() as sess: 302 with self.test_scope(): 303 y = _polygamma(np.float64(-1.), x) 304 actual = sess.run(y) 305 # Not defined for negative numbers. 306 self.assertTrue(np.all(np.isnan(actual))) 307 308 with self.session() as sess: 309 with self.test_scope(): 310 y = _polygamma(np.float64(0.1), x) 311 actual = sess.run(y) 312 # Not defined for non-integers. 313 self.assertTrue(np.all(np.isnan(actual))) 314 315 @parameterized.parameters((np.float32, 1e-2, 1e-11), 316 (np.float64, 1e-4, 1e-30)) 317 def testRecoverDigamma(self, dtype, rtol, atol): 318 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 319 if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: 320 self.skipTest( 321 'Skipping test because some F64 operations are ' 322 'numerically unstable on TPU.' 323 ) 324 325 x = np.random.uniform(low=0.1, high=50., size=[NUM_SAMPLES]).astype(dtype) 326 expected_values = sps.digamma(x) 327 with self.session() as sess: 328 with self.test_scope(): 329 y = _polygamma(dtype(0.), x) 330 actual = sess.run(y) 331 332 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 333 334 @parameterized.parameters((np.float32, 1e-2, 1e-11), 335 (np.float64, 1e-4, 1e-30)) 336 def testSmallN(self, dtype, rtol, atol): 337 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 338 # Test values near zero. 339 n = np.random.randint(low=1, high=5, size=[NUM_SAMPLES]).astype(dtype) 340 x = np.random.uniform( 341 low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) 342 343 expected_values = sps.polygamma(n, x) 344 with self.session() as sess: 345 with self.test_scope(): 346 actual = sess.run(_polygamma(n, x)) 347 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 348 349 @parameterized.parameters((np.float32, 1e-2, 1e-11), 350 (np.float64, 1e-4, 1e-30)) 351 def testMediumLargeN(self, dtype, rtol, atol): 352 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 353 n = np.random.randint(low=5, high=10, size=[NUM_SAMPLES]).astype(dtype) 354 x = np.random.uniform(low=1., high=1e1, size=[NUM_SAMPLES]).astype(dtype) 355 356 expected_values = sps.polygamma(n, x) 357 with self.session() as sess: 358 with self.test_scope(): 359 actual = sess.run(_polygamma(n, x)) 360 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 361 362 363class IgammaTest(xla_test.XLATestCase, parameterized.TestCase): 364 365 def setUp(self): 366 if flags.FLAGS.vary_seed: 367 entropy = os.urandom(64) 368 if six.PY2: 369 answer = int(entropy.encode('hex'), 16) 370 else: 371 answer = int.from_bytes(entropy, 'big') 372 np.random.seed(answer % (2**32 - 1)) 373 super(IgammaTest, self).setUp() 374 375 # Skip Float64 test on TPU due to missing ops. 376 def maybe_skip_test(self, dtype): 377 if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: 378 self.skipTest( 379 'Skipping test because some F64 operations not supported on TPU.') 380 381 def adjust_tolerance_for_tpu(self, dtype, rtol, atol): 382 if self.device not in ['TPU']: 383 return rtol, atol 384 385 if dtype == np.float32: 386 return 2e-2, 1e-7 387 return 2e-4, 1e-20 388 389 @parameterized.parameters((np.float32, 1e-2, 1e-11), 390 (np.float64, 1e-4, 1e-30)) 391 def testLargeXSmallA(self, dtype, rtol, atol): 392 self.maybe_skip_test(dtype) 393 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 394 # Test values near zero. 395 x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype) 396 a = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype) 397 398 expected_values = sps.gammainc(a, x) 399 with self.session() as sess: 400 with self.test_scope(): 401 y = _igamma(a, x) 402 actual = sess.run(y) 403 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 404 405 @parameterized.parameters((np.float32, 1e-2, 1e-11), 406 (np.float64, 1e-4, 1e-30)) 407 def testSmallValues(self, dtype, rtol, atol): 408 self.maybe_skip_test(dtype) 409 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 410 # Test values near zero. 411 x = np.random.uniform( 412 low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) 413 a = np.random.uniform( 414 low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) 415 416 expected_values = sps.gammainc(a, x) 417 with self.session() as sess: 418 with self.test_scope(): 419 actual = sess.run(_igamma(a, x)) 420 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 421 422 @parameterized.parameters((np.float32, 1e-2, 1e-11), 423 (np.float64, 1e-4, 1e-30)) 424 def testMediumValues(self, dtype, rtol, atol): 425 self.maybe_skip_test(dtype) 426 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 427 # Test values near zero. 428 x = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) 429 a = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) 430 431 expected_values = sps.gammainc(a, x) 432 with self.session() as sess: 433 with self.test_scope(): 434 actual = sess.run(_igamma(a, x)) 435 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 436 437 @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30)) 438 def testLargeValues(self, dtype, rtol, atol): 439 if self.device == 'TPU': 440 # TODO(b/154908275): Remove this once fixed for large a, x. 441 self.skipTest('Skipping test since numerically unstable on TPU.') 442 # Test values near zero. 443 x = np.random.uniform( 444 low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) 445 a = np.random.uniform( 446 low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) 447 448 expected_values = sps.gammainc(a, x) 449 with self.session() as sess: 450 with self.test_scope(): 451 actual = sess.run(_igamma(a, x)) 452 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 453 454 # We don't check small values because the numerical gradients become quite 455 # large. 456 @parameterized.parameters((np.float32, 0.09), (np.float64, 1e-7)) 457 def testGradMediumValues(self, dtype, tolerance): 458 self.maybe_skip_test(dtype) 459 with self.session(): 460 with self.test_scope(): 461 x = constant_op.constant( 462 np.random.uniform(low=1., high=100., 463 size=[NUM_SAMPLES]).astype(dtype)) 464 a = constant_op.constant( 465 np.random.uniform(low=1., high=100., 466 size=[NUM_SAMPLES]).astype(dtype)) 467 468 f = lambda b: _igamma(b, x) 469 max_error = gradient_checker_v2.max_error( 470 *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-3)) 471 self.assertLessEqual(max_error, tolerance) 472 473 @parameterized.parameters((np.float32, 0.5), (np.float64, 1e-7)) 474 def testGradLargeValues(self, dtype, tolerance): 475 self.maybe_skip_test(dtype) 476 with self.session(): 477 with self.test_scope(): 478 x = constant_op.constant( 479 np.random.uniform(low=100., high=int(1e4), 480 size=[NUM_SAMPLES]).astype(dtype)) 481 a = constant_op.constant( 482 np.random.uniform(low=100., high=int(1e4), 483 size=[NUM_SAMPLES]).astype(dtype)) 484 485 f = lambda b: _igamma(b, x) 486 max_error = gradient_checker_v2.max_error( 487 *gradient_checker_v2.compute_gradient(f, x=[a], delta=1e-2)) 488 self.assertLessEqual(max_error, tolerance) 489 490 @parameterized.parameters((np.float32, 1e-2, 1e-11), 491 (np.float64, 1e-4, 1e-30)) 492 def testRandomGammaGradSmallValues(self, dtype, rtol, atol): 493 self.maybe_skip_test(dtype) 494 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 495 # Test values near zero. 496 497 with self.session() as sess: 498 with self.test_scope(): 499 x = constant_op.constant( 500 np.random.uniform( 501 low=np.finfo(dtype).tiny, high=1., 502 size=[NUM_SAMPLES]).astype(dtype)) 503 a = constant_op.constant( 504 np.random.uniform( 505 low=np.finfo(dtype).tiny, high=1., 506 size=[NUM_SAMPLES]).astype(dtype)) 507 gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x) 508 actual_grad = implicit_reparameterization_grad(a, x) 509 gamma_sample_grad, actual_grad = sess.run( 510 [gamma_sample_grad, actual_grad]) 511 # We do this because the ratio computed in 512 # implicit_reparameterization_grad can very easily result in a NaN due 513 # to the computed numerator and denominator zeroing out. 514 gamma_sample_grad = gamma_sample_grad[ 515 ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] 516 actual_grad = actual_grad[ 517 ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] 518 self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol) 519 520 @parameterized.parameters((np.float32, 1e-2, 1e-11), 521 (np.float64, 1e-4, 1e-30)) 522 def testRandomGammaGradMediumValues(self, dtype, rtol, atol): 523 self.maybe_skip_test(dtype) 524 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 525 526 with self.session() as sess: 527 with self.test_scope(): 528 x = constant_op.constant( 529 np.random.uniform(low=1., high=10., 530 size=[NUM_SAMPLES]).astype(dtype)) 531 a = constant_op.constant( 532 np.random.uniform(low=1., high=10., 533 size=[NUM_SAMPLES]).astype(dtype)) 534 gamma_sample_grad = gen_random_ops.random_gamma_grad(a, x) 535 actual_grad = implicit_reparameterization_grad(a, x) 536 gamma_sample_grad, actual_grad = sess.run( 537 [gamma_sample_grad, actual_grad]) 538 # We do this because the ratio computed in 539 # implicit_reparameterization_grad can very easily result in a NaN due 540 # to the computed numerator and denominator zeroing out. 541 gamma_sample_grad = gamma_sample_grad[ 542 ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] 543 actual_grad = actual_grad[ 544 ~np.logical_or(np.isnan(actual_grad), np.isinf(actual_grad))] 545 self.assertAllClose(actual_grad, gamma_sample_grad, atol=atol, rtol=rtol) 546 547 548class IgammacTest(xla_test.XLATestCase, parameterized.TestCase): 549 550 def setUp(self): 551 if flags.FLAGS.vary_seed: 552 entropy = os.urandom(64) 553 if six.PY2: 554 answer = int(entropy.encode('hex'), 16) 555 else: 556 answer = int.from_bytes(entropy, 'big') 557 np.random.seed(answer % (2**32 - 1)) 558 super(IgammacTest, self).setUp() 559 560 # Skip Float64 test on TPU due to missing ops. 561 def maybe_skip_test(self, dtype): 562 if self.device not in ['XLA_GPU', 'XLA_CPU'] and dtype == np.float64: 563 # TODO(b/154908275): Remove this once fixed for large a, x. 564 self.skipTest( 565 'Skipping test because some F64 operations not supported on TPU.') 566 567 def adjust_tolerance_for_tpu(self, dtype, rtol, atol): 568 if self.device not in ['TPU']: 569 return rtol, atol 570 571 if dtype == np.float32: 572 return 2e-2, 1e-7 573 return 2e-4, 1e-20 574 575 @parameterized.parameters((np.float32, 1e-2, 1e-11), 576 (np.float64, 1e-4, 1e-30)) 577 def testLargeXSmallA(self, dtype, rtol, atol): 578 self.maybe_skip_test(dtype) 579 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 580 # Test values near zero. 581 x = np.random.uniform(low=100., high=200., size=[NUM_SAMPLES]).astype(dtype) 582 a = np.random.uniform(low=0.3, high=1., size=[NUM_SAMPLES]).astype(dtype) 583 584 expected_values = sps.gammaincc(a, x) 585 with self.session() as sess: 586 with self.test_scope(): 587 y = _igammac(a, x) 588 actual = sess.run(y) 589 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 590 591 @parameterized.parameters((np.float32, 1e-2, 1e-11), 592 (np.float64, 1e-4, 1e-30)) 593 def testSmallValues(self, dtype, rtol, atol): 594 self.maybe_skip_test(dtype) 595 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 596 # Test values near zero. 597 x = np.random.uniform( 598 low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) 599 a = np.random.uniform( 600 low=np.finfo(dtype).tiny, high=1., size=[NUM_SAMPLES]).astype(dtype) 601 602 expected_values = sps.gammaincc(a, x) 603 with self.session() as sess: 604 with self.test_scope(): 605 actual = sess.run(_igammac(a, x)) 606 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 607 608 @parameterized.parameters((np.float32, 1e-2, 1e-11), 609 (np.float64, 1e-4, 1e-30)) 610 def testMediumValues(self, dtype, rtol, atol): 611 self.maybe_skip_test(dtype) 612 rtol, atol = self.adjust_tolerance_for_tpu(dtype, rtol, atol) 613 # Test values near zero. 614 x = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) 615 a = np.random.uniform(low=1., high=100., size=[NUM_SAMPLES]).astype(dtype) 616 617 expected_values = sps.gammaincc(a, x) 618 with self.session() as sess: 619 with self.test_scope(): 620 actual = sess.run(_igammac(a, x)) 621 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 622 623 @parameterized.parameters((np.float32, 2e-2, 1e-5), (np.float64, 1e-4, 1e-30)) 624 def testLargeValues(self, dtype, rtol, atol): 625 if self.device == 'TPU': 626 self.skipTest('Skipping test since numerically unstable on TPU.') 627 # Test values near zero. 628 x = np.random.uniform( 629 low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) 630 a = np.random.uniform( 631 low=100., high=int(1e4), size=[NUM_SAMPLES]).astype(dtype) 632 633 expected_values = sps.gammaincc(a, x) 634 with self.session() as sess: 635 with self.test_scope(): 636 actual = sess.run(_igammac(a, x)) 637 self.assertAllClose(expected_values, actual, atol=atol, rtol=rtol) 638 639 640if __name__ == '__main__': 641 os.environ['XLA_FLAGS'] = '--xla_cpu_enable_fast_math=false' 642 test.main() 643