# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for Python ops defined in math_grad.py."""

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class SquaredDifferenceOpTest(test.TestCase):

  def _testGrad(self, left_shape, right_shape):

    if len(left_shape) > len(right_shape):
      output_shape = left_shape
    else:
      output_shape = right_shape
    l = np.random.randn(*left_shape)
    r = np.random.randn(*right_shape)

    with self.cached_session():
      left_tensor = constant_op.constant(l, shape=left_shape)
      right_tensor = constant_op.constant(r, shape=right_shape)
      output = math_ops.squared_difference(left_tensor, right_tensor)
      left_err = gradient_checker.compute_gradient_error(
          left_tensor, left_shape, output, output_shape, x_init_value=l)
      right_err = gradient_checker.compute_gradient_error(
          right_tensor, right_shape, output, output_shape, x_init_value=r)
      self.assertLess(left_err, 1e-10)
      self.assertLess(right_err, 1e-10)

  @test_util.run_deprecated_v1
  def testGrad(self):
    self._testGrad([1, 2, 3, 2], [3, 2])
    self._testGrad([2, 4], [3, 2, 4])


class AbsOpTest(test.TestCase):

  def _biasedRandN(self, shape, bias=0.1, sigma=1.0):
    """Returns samples from a normal distribution shifted `bias` away from 0."""
    value = np.random.randn(*shape) * sigma
    return value + np.sign(value) * bias

  def _testGrad(self, shape, dtype=None, max_error=None, bias=None,
                sigma=None):
    np.random.seed(7)
    if dtype in (dtypes.complex64, dtypes.complex128):
      value = math_ops.complex(
          self._biasedRandN(
              shape, bias=bias, sigma=sigma),
          self._biasedRandN(
              shape, bias=bias, sigma=sigma))
    else:
      value = ops.convert_to_tensor(
          self._biasedRandN(
              shape, bias=bias), dtype=dtype)

    with self.cached_session():
      output = math_ops.abs(value)
      error = gradient_checker.compute_gradient_error(
          value, shape, output, output.get_shape().as_list())
      self.assertLess(error, max_error)

  @test_util.run_deprecated_v1
  def testComplexAbs(self):
    # Bias random test values away from zero to avoid numeric instabilities.
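    # The derivative of abs(x) is sign(x) for real x and, component-wise,
    # real(x) / |x| and imag(x) / |x| for complex x, so it is undefined at 0.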
    self._testGrad(
        [3, 3], dtype=dtypes.float32, max_error=2e-5, bias=0.1, sigma=1.0)
    self._testGrad(
        [3, 3], dtype=dtypes.complex64, max_error=2e-5, bias=0.1, sigma=1.0)

    # Ensure stability near the pole at zero.
    self._testGrad(
        [3, 3], dtype=dtypes.float32, max_error=100.0, bias=0.0, sigma=0.1)
    self._testGrad(
        [3, 3], dtype=dtypes.complex64, max_error=100.0, bias=0.0, sigma=0.1)


class MinOrMaxGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testMinGradient(self):
    inputs = constant_op.constant([1.0], dtype=dtypes.float32)
    outputs = math_ops.reduce_min(array_ops.concat([inputs, inputs], 0))
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [1], outputs, [])
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testMaxGradient(self):
    inputs = constant_op.constant([1.0], dtype=dtypes.float32)
    outputs = math_ops.reduce_max(array_ops.concat([inputs, inputs], 0))
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [1], outputs, [])
      self.assertLess(error, 1e-4)


class MaximumOrMinimumGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testMaximumGradient(self):
    inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
    outputs = math_ops.maximum(inputs, 3.0)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [4], outputs, [4])
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testMinimumGradient(self):
    inputs = constant_op.constant([1.0, 2.0, 3.0, 4.0], dtype=dtypes.float32)
    outputs = math_ops.minimum(inputs, 2.0)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [4], outputs, [4])
      self.assertLess(error, 1e-4)


class ProdGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testProdGradient(self):
    inputs = constant_op.constant([[1., 2.], [3., 4.]],
                                  dtype=dtypes.float32)
    outputs = math_ops.reduce_prod(inputs)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(
          inputs, inputs.get_shape().as_list(),
          outputs, outputs.get_shape().as_list())
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testProdGradientForNegativeAxis(self):
    inputs = constant_op.constant([[1., 2.], [3., 4.]],
                                  dtype=dtypes.float32)
    outputs = math_ops.reduce_prod(inputs, -1)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(
          inputs, inputs.get_shape().as_list(),
          outputs, outputs.get_shape().as_list())
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testProdGradientComplex(self):
    for dtype in dtypes.complex64, dtypes.complex128:
      inputs = constant_op.constant([[1 + 3j, 2 - 1j], [3j, 4]],
                                    dtype=dtype)
      outputs = math_ops.reduce_prod(inputs)
      with self.cached_session():
        error = gradient_checker.compute_gradient_error(
            inputs, inputs.get_shape().as_list(),
            outputs, outputs.get_shape().as_list())
        self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testProdGradientForNegativeAxisComplex(self):
    for dtype in dtypes.complex64, dtypes.complex128:
      inputs = constant_op.constant([[1 + 3j, 2 - 1j], [3j, 4]],
                                    dtype=dtype)
      outputs = math_ops.reduce_prod(inputs, -1)
      with self.cached_session():
        error = gradient_checker.compute_gradient_error(
            inputs, inputs.get_shape().as_list(),
            outputs, outputs.get_shape().as_list())
        self.assertLess(error, 1e-4)


@test_util.run_all_in_graph_and_eager_modes
class EuclideanNormGradientTest(test.TestCase):

  def testBasic(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([3], dtype=dtype)
      grad = gradient_checker_v2.compute_gradient(
          math_ops.reduce_euclidean_norm, [x])
      err = gradient_checker_v2.max_error(*grad)
      self.assertLess(err, 1e-3)

  def testNegative(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([-3], dtype=dtype)
      grad = gradient_checker_v2.compute_gradient(
          math_ops.reduce_euclidean_norm, [x])
      err = gradient_checker_v2.max_error(*grad)
      self.assertLess(err, 1e-3)

  def testKeepdims(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([3], dtype=dtype)
      grad = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, keepdims=True), [x])
      err = gradient_checker_v2.max_error(*grad)
      self.assertLess(err, 1e-3)

  def testGradientChain(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([3], dtype=dtype)
      grad = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x) * 5, [x])
      err = gradient_checker_v2.max_error(*grad)
      self.assertLess(err, 1e-3)

  def testTwoElements(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([3, -4], dtype=dtype)
      grad = gradient_checker_v2.compute_gradient(
          math_ops.reduce_euclidean_norm, [x])
      err = gradient_checker_v2.max_error(*grad)
      self.assertLess(err, 1e-3)

  def testNegativeZero(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([1.0, -0.0], dtype=dtype)

      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = math_ops.reduce_euclidean_norm(x)

      dx = tape.gradient(y, x)
      dx_answer = constant_op.constant([1.0, -0.0], dtype=dtype)
      self.assertAllClose(dx, dx_answer)
      # Comparing reciprocals distinguishes -0. from 0. (1. / -0. == -inf).
      self.assertAllClose(1.0 / dx, 1.0 / dx_answer)

  def testZeros(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([0.0, -0.0], dtype=dtype)

      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = math_ops.reduce_euclidean_norm(x)

      dx = tape.gradient(y, x)
      # The gradient x / norm(x) is 0 / 0 when every element is zero.
      dx_answer = constant_op.constant(
          [float("NaN"), float("NaN")], dtype=dtype)
      self.assertAllClose(dx, dx_answer)

  def test2D_1(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[-3, 5], [7, 11]], dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          math_ops.reduce_euclidean_norm, [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 1e-3)

  def test2D_2(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[-3, 5], [7, 11]], dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, 0), [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 1e-3)

  def test2D_3(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[-3, 5], [7, 11]], dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, 1), [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 1e-3)

  def test2D_4(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[3], [4]], dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, 1), [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 1e-3)

  def test3D_1(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[[-3, 5], [7, 11]], [[13, 17], [19, 23]]],
                               dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          math_ops.reduce_euclidean_norm, [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 2e-3)

  def test3D_2(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[[-3, 5], [7, 11]], [[13, 17], [19, 23]]],
                               dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, 0), [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 2e-3)

  def test3D_3(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[[-3, 5], [7, 11]], [[13, 17], [19, 23]]],
                               dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, 1), [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 3e-3)

  def test3D_4(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x = constant_op.constant([[[-3, 5], [7, 11]], [[13, 17], [19, 23]]],
                               dtype=dtype)
      grads = gradient_checker_v2.compute_gradient(
          lambda x: math_ops.reduce_euclidean_norm(x, 2), [x])
      err = gradient_checker_v2.max_error(*grads)
      self.assertLess(err, 2e-3)


class SegmentMinOrMaxGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testSegmentMinGradient(self):
    data = constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.float32)
    segment_ids = constant_op.constant([0, 0, 1], dtype=dtypes.int64)
    segment_min = math_ops.segment_min(data, segment_ids)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(data, [3], segment_min,
                                                      [2])
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testSegmentMaxGradient(self):
    data = constant_op.constant([1.0, 2.0, 3.0], dtype=dtypes.float32)
    segment_ids = constant_op.constant([0, 0, 1], dtype=dtypes.int64)
    segment_max = math_ops.segment_max(data, segment_ids)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(data, [3], segment_max,
                                                      [2])
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testSegmentMinGradientWithTies(self):
    inputs = constant_op.constant([1.0], dtype=dtypes.float32)
    data = array_ops.concat([inputs, inputs], 0)
    segment_ids = constant_op.constant([0, 0], dtype=dtypes.int64)
    segment_min = math_ops.segment_min(data, segment_ids)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [1], segment_min,
                                                      [1])
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testSegmentMaxGradientWithTies(self):
    inputs = constant_op.constant([1.0], dtype=dtypes.float32)
    data = array_ops.concat([inputs, inputs], 0)
    segment_ids = constant_op.constant([0, 0], dtype=dtypes.int64)
    segment_max = math_ops.segment_max(data, segment_ids)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [1], segment_max,
                                                      [1])
      self.assertLess(error, 1e-4)


@test_util.run_all_in_graph_and_eager_modes
class SegmentProdGradientTest(test.TestCase):

  def _run_gradient_check(self, data, segment_ids):

    def _segment_prod(x):
      return math_ops.segment_prod(x, segment_ids)

    err = gradient_checker_v2.max_error(
        *gradient_checker_v2.compute_gradient(_segment_prod, [data]))
    self.assertLess(err, 2e-4)

  def testSegmentProdGradientWithoutOverlap(self):
    data = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
                                dtype=dtypes.float32)
    segment_ids = constant_op.constant([0, 1, 2], dtype=dtypes.int64)
    self._run_gradient_check(data, segment_ids)

  def testSegmentProdGradientWithoutZeros(self):
    data = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
                                dtype=dtypes.float32)
    segment_ids = constant_op.constant([0, 0, 1], dtype=dtypes.int64)
    self._run_gradient_check(data, segment_ids)

  def testSegmentProdGradientWithZeros(self):
    data = constant_op.constant([[0, 2, 3, 4], [0, 0, 2, 0], [5, 0, 7, 0]],
                                dtype=dtypes.float32)
    segment_ids = constant_op.constant([0, 0, 1], dtype=dtypes.int64)
    self._run_gradient_check(data, segment_ids)

  def testSegmentProdGradientWithEmptySegment(self):
    data = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1], [5, 6, 7, 8]],
                                dtype=dtypes.float32)
    segment_ids = constant_op.constant([0, 0, 2], dtype=dtypes.int64)
    self._run_gradient_check(data, segment_ids)


class FloorModGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testFloorModGradient(self):
    # Making sure the input is not near the discontinuity point where
    # x/y == floor(x/y)
    ns = constant_op.constant([17.], dtype=dtypes.float32)
    inputs = constant_op.constant([131.], dtype=dtypes.float32)
    floor_mod = math_ops.floormod(inputs, ns)
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(inputs, [1],
                                                      floor_mod, [1])
      self.assertLess(error, 1e-4)


class DivNoNanGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testBasicGradient(self):
    inputs = constant_op.constant(np.arange(-3, 3),
                                  dtype=dtypes.float32)
    outputs = math_ops.div_no_nan(inputs, 1 + math_ops.abs(inputs))
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(
          inputs,
          inputs.get_shape().as_list(), outputs,
          outputs.get_shape().as_list())
      self.assertLess(error, 1e-4)

  @test_util.run_deprecated_v1
  def testGradientWithDenominatorIsZero(self):
    x = constant_op.constant(np.arange(-3, 3),
                             dtype=dtypes.float32)
    y = array_ops.zeros_like(x,
                             dtype=dtypes.float32)
    outputs = math_ops.div_no_nan(x, y)
    with self.cached_session():
      dx, dy = gradients.gradients(outputs, [x, y])
      self.assertAllClose(dx, np.zeros(x.shape.as_list()))
      self.assertAllClose(dy, np.zeros(y.shape.as_list()))


class MulNoNanGradientTest(test.TestCase):

  @test_util.run_deprecated_v1
  def testBasicGradient(self):
    inputs = constant_op.constant(np.arange(-3, 3), dtype=dtypes.float32)
    outputs = math_ops.mul_no_nan(inputs, 1 + math_ops.abs(inputs))
    with self.cached_session():
      error = gradient_checker.compute_gradient_error(
          inputs,
          inputs.get_shape().as_list(), outputs,
          outputs.get_shape().as_list())
      self.assertLess(error, 1e-4)
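
  # mul_no_nan(x, y) is defined to be zero wherever y == 0, even if x is NaN or
  # infinite, so the gradient w.r.t. x is zero there while the gradient w.r.t.
  # y is x itself; the test below checks that dy reproduces x_vals exactly.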
  @test_util.run_deprecated_v1
  def testGradientWithRhsIsZero(self):
    x_vals = [0, 1.0, np.nan, np.inf, np.NINF]
    x = constant_op.constant(x_vals, dtype=dtypes.float32)
    y = array_ops.zeros_like(x, dtype=dtypes.float32)
    outputs = math_ops.mul_no_nan(x, y)
    with self.cached_session():
      dx, dy = gradients.gradients(outputs, [x, y])
      self.assertAllClose(dx, np.zeros(x.shape.as_list()))
      self.assertAllClose(dy, x_vals)


class XlogyTest(test.TestCase):

  def _xlogy_gradients(self, x, y):
    xlogy_xgrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), x)[0])
    xlogy_ygrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), y)[0])
    return xlogy_xgrad, xlogy_ygrad

  @test_util.run_deprecated_v1
  def testNonZeroValuesGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0.1, dtype=dtype)
      y = constant_op.constant(3.1, dtype=dtype)
      xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
      xlogy_expected_xgrad = self.evaluate(math_ops.log(y))
      xlogy_expected_ygrad = self.evaluate(x / y)
      self.assertAllClose(xlogy_expected_xgrad, xlogy_xgrad)
      self.assertAllClose(xlogy_expected_ygrad, xlogy_ygrad)

  @test_util.run_deprecated_v1
  def testZeroXGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0., dtype=dtype)
      y = constant_op.constant(3.1, dtype=dtype)
      xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
      zero = self.evaluate(x)
      self.assertAllClose(zero, xlogy_xgrad)
      self.assertAllClose(zero, xlogy_ygrad)

  @test_util.run_deprecated_v1
  def testZeroYGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0.1, dtype=dtype)
      y = constant_op.constant(0., dtype=dtype)
      xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
      self.assertAllClose(-np.inf, xlogy_xgrad)
      self.assertAllClose(np.inf, xlogy_ygrad)

  @test_util.run_deprecated_v1
  def testZeroXYGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0., dtype=dtype)
      y = constant_op.constant(0., dtype=dtype)
      xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
      zero = self.evaluate(x)
      self.assertAllClose(zero, xlogy_xgrad)
      self.assertAllClose(zero, xlogy_ygrad)


class Xlog1pyTest(test.TestCase):

  def _xlog1py_gradients(self, x, y):
    xlog1py_xgrad = self.evaluate(
        gradients.gradients(math_ops.xlog1py(x, y), x)[0])
    xlog1py_ygrad = self.evaluate(
        gradients.gradients(math_ops.xlog1py(x, y), y)[0])
    return xlog1py_xgrad, xlog1py_ygrad

  @test_util.run_deprecated_v1
  def testNonZeroValuesGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0.1, dtype=dtype)
      y = constant_op.constant(3.1, dtype=dtype)
      xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
      xlog1py_expected_xgrad = self.evaluate(math_ops.log1p(y))
      xlog1py_expected_ygrad = self.evaluate(x / (1. + y))
      self.assertAllClose(xlog1py_expected_xgrad, xlog1py_xgrad)
      self.assertAllClose(xlog1py_expected_ygrad, xlog1py_ygrad)
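
  # xlog1py(x, y) is identically zero along x == 0, and its gradient returns
  # zero for both partial derivatives there (not log1p(y) for the x partial),
  # as testZeroXGrad and testZeroXNegOneYGrad verify.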
  @test_util.run_deprecated_v1
  def testZeroXGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0., dtype=dtype)
      y = constant_op.constant(3.1, dtype=dtype)
      xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
      zero = self.evaluate(x)
      self.assertAllClose(zero, xlog1py_xgrad)
      self.assertAllClose(zero, xlog1py_ygrad)

  @test_util.run_deprecated_v1
  def testNegOneYGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0.1, dtype=dtype)
      y = constant_op.constant(-1., dtype=dtype)
      xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
      self.assertAllClose(-np.inf, xlog1py_xgrad)
      self.assertAllClose(np.inf, xlog1py_ygrad)

  @test_util.run_deprecated_v1
  def testZeroXNegOneYGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0., dtype=dtype)
      y = constant_op.constant(-1., dtype=dtype)
      xlog1py_xgrad, xlog1py_ygrad = self._xlog1py_gradients(x, y)
      zero = self.evaluate(x)
      self.assertAllClose(zero, xlog1py_xgrad)
      self.assertAllClose(zero, xlog1py_ygrad)


class XdivyTest(test.TestCase):

  def _xdivy_gradients(self, x, y):
    xdivy_xgrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), x)[0])
    xdivy_ygrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), y)[0])
    return xdivy_xgrad, xdivy_ygrad

  @test_util.run_deprecated_v1
  def testNonZeroValuesGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0.1, dtype=dtype)
      y = constant_op.constant(3.1, dtype=dtype)
      xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
      xdivy_expected_xgrad = self.evaluate(1 / y)
      xdivy_expected_ygrad = self.evaluate(-x / y**2)
      self.assertAllClose(xdivy_expected_xgrad, xdivy_xgrad)
      self.assertAllClose(xdivy_expected_ygrad, xdivy_ygrad)

  @test_util.run_deprecated_v1
  def testZeroXGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0., dtype=dtype)
      y = constant_op.constant(3.1, dtype=dtype)
      xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
      zero = self.evaluate(x)
      self.assertAllClose(zero, xdivy_xgrad)
      self.assertAllClose(zero, xdivy_ygrad)

  @test_util.run_deprecated_v1
  def testZeroYGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0.1, dtype=dtype)
      y = constant_op.constant(0., dtype=dtype)
      xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
      self.assertAllClose(np.inf, xdivy_xgrad)
      self.assertAllClose(-np.inf, xdivy_ygrad)

  @test_util.run_deprecated_v1
  def testZeroXYGrad(self):
    for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
      x = constant_op.constant(0., dtype=dtype)
      y = constant_op.constant(0., dtype=dtype)
      xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
      zero = self.evaluate(x)
      self.assertAllClose(zero, xdivy_xgrad)
      self.assertAllClose(zero, xdivy_ygrad)


@test_util.run_all_in_graph_and_eager_modes
class PowGradTest(test.TestCase):

  def test_zero_grad_tf_gradients(self):
    if context.executing_eagerly():
      self.skipTest("tf.gradients not supported in eager.")

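    # d/dx x**2 = 2 * x, so the gradient at [-1., 0., 1.] should be
    # [-2., 0., 2.], with an exact 0. (not NaN) at x == 0.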
    x = constant_op.constant([-1., 0., 1.])
    g = self.evaluate(gradients.gradients(math_ops.pow(x, 2), x)[0])
    self.assertAllClose([-2., 0., 2.], g)

  def test_zero_grad_tape(self):
    x = constant_op.constant([-1, 0., 1.])
    with backprop.GradientTape() as tape:
      tape.watch(x)
      g = tape.gradient(math_ops.pow(x, 2), x)
    g = self.evaluate(g)
    self.assertAllClose([-2., 0., 2.], g)


@test_util.run_all_in_graph_and_eager_modes
class NextAfterTest(test.TestCase):

  def _nextafter_gradient(self, x1, x2):
    with backprop.GradientTape() as tape:
      tape.watch(x1)
      tape.watch(x2)
      y = math_ops.nextafter(x1, x2)
      return tape.gradient(y, [x1, x2])

  def testBasic(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      x1 = constant_op.constant(0.1, dtype=dtype)
      x2 = constant_op.constant(3.1, dtype=dtype)
      dx1, dx2 = self._nextafter_gradient(x1, x2)
      expected_dx1 = constant_op.constant(1, dtype=dtype)
      expected_dx2 = constant_op.constant(0, dtype=dtype)
      self.assertAllClose(expected_dx1, dx1)
      self.assertAllClose(expected_dx2, dx2)

  def testDynamicShapes(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      default_x1 = constant_op.constant(0.1, dtype=dtype)
      default_x2 = constant_op.constant(3.1, dtype=dtype)
      x1 = array_ops.placeholder_with_default(default_x1, shape=None)
      x2 = array_ops.placeholder_with_default(default_x2, shape=None)
      dx1, dx2 = self._nextafter_gradient(x1, x2)
      expected_dx1 = constant_op.constant(1, dtype=dtype)
      expected_dx2 = constant_op.constant(0, dtype=dtype)
      self.assertAllClose(expected_dx1, dx1)
      self.assertAllClose(expected_dx2, dx2)

  def testWithGradientChecker(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session():
        x1 = np.array([-1, 0, 1, 2, 3], dtype=dtype.as_numpy_dtype)
        x2 = np.array([2, 2, 2, 2, 2], dtype=dtype.as_numpy_dtype)
        err = gradient_checker_v2.max_error(
            *gradient_checker_v2.compute_gradient(
                lambda x: math_ops.nextafter(x, x2), [x1]))  # pylint: disable=cell-var-from-loop
        self.assertLess(err, 1e-3)

  def testBroadcastingWithGradientChecker(self):
    for dtype in [dtypes.float32, dtypes.float64]:
      with self.cached_session():
        x1 = np.array([-1, 0, 1, 2, 3], dtype=dtype.as_numpy_dtype)
        x2 = np.array([2], dtype=dtype.as_numpy_dtype)
        err = gradient_checker_v2.max_error(
            *gradient_checker_v2.compute_gradient(
                lambda x: math_ops.nextafter(x, x2), [x1]))  # pylint: disable=cell-var-from-loop
        self.assertLess(err, 1e-3)


if __name__ == "__main__":
  test.main()