# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""

import sys
import warnings

from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import indexed_slices
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import unconnected_gradients
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
from tensorflow.python.util import nest


class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
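      # For h = xw + b with xw = inp @ w, the gradient w.r.t. w is
      # inp^T @ grad_h, so the backward op is expected to be a MatMul with
      # transpose_a=True and transpose_b=False.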
      h = bias_add(xw, b, name="h")
      w_grad = gradients.gradients(h, w)[0]
      self.assertEqual("MatMul", w_grad.op.type)
      self.assertEqual(w_grad.op._original_op, xw.op)
      self.assertTrue(w_grad.op.get_attr("transpose_a"))
      self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
      gw = gradients.gradients(c, [w])[0]
      self.assertEqual("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
        wx = math_ops.matmul(w, x)
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertNotEqual(wx.op.colocation_groups(), gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertNotEqual(w.op.colocation_groups(), gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]

      # Make sure the placer doesn't complain.
      self.evaluate(gz_x)

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  @test_util.run_v1_only("b/120545219")
  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAccumulateN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAddN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodTree(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testNoGradientForStringOutputsWithOpNamespace(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.namespace_test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = indexed_slices.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.cached_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
      self.assertEqual(vdx, vdy)

  @test_util.run_v1_only("b/120545219")
  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testVariableDefaultGrad(self, dtype):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, dtype=dtype)
      var = variables.Variable(init)
      dummy_const = constant_op.constant(0.0)
      gradient = gradients.gradients(
          dummy_const,
          var,
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
      )[0]
      self.assertEqual(gradient.dtype, dtype)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.VariableV1(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testDependentYs(self):
    with self.cached_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
      g = gradients.gradients([y, y2], x)
      self.assertAllClose(17502.0, g[0])
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0])
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0])

  @test_util.run_v1_only("b/120545219")
  def testPartialDerivatives(self):
    with self.cached_session():
      x = constant_op.constant(1.)
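      # With y = 2 * x and z = x + y, dz/dx = 1 + 2 = 3 through the chain
      # rule, while stop_gradients=[x, y] treats both as constants, leaving
      # only the direct partials of 1 each.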
      y = 2 * x
      z = x + y
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  @test_util.run_v1_only("b/120545219")
  def testStopGradients(self):
    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))

      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradients
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.cached_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
      self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")

      self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")

      self.assertEqual(3.0, self.evaluate(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegex(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")

  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
                            unconnected_gradients.UnconnectedGradients.NONE)
  def testUnconnectedOpWithMultipleOutputs(self, unconnected_gradients_val):
    with ops.Graph().as_default():
      # a    b
      # |    |
      # IdentityN
      # |    |
      # c    d
      # |
      # Identity
      # |
      # e
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      c, d = array_ops.identity_n([a, b])
      e = array_ops.identity(c)
      # The aggregated grads for the IdentityN node would look like
      # [Tensor, None]. We expect this None to be converted to zeros.
      output = gradients.gradients(
          e, d, unconnected_gradients=unconnected_gradients_val)
      if (unconnected_gradients_val ==
          unconnected_gradients.UnconnectedGradients.ZERO):
        self.assertIsNotNone(output[0])
      else:
        self.assertIsNone(output[0])

  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
                            unconnected_gradients.UnconnectedGradients.NONE)
  def testUnconnectedOpWithMultipleOutputsStopGradient(
      self, unconnected_gradients_val):
    with ops.Graph().as_default():
      # a    b
      # |    |
      # IdentityN
      # |    |
      # c    d
      # |    |
      # SG   |
      # |    |
      #  \  /
      #   +
      #   e
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      c, d = array_ops.identity_n([a, b])
      e = array_ops.stop_gradient(c) + d
      # The aggregated grads for the IdentityN node would look like
      # [None, Tensor]. We expect this None to be converted to zeros.
      output = gradients.gradients(
          e, c, unconnected_gradients=unconnected_gradients_val)
      if (unconnected_gradients_val ==
          unconnected_gradients.UnconnectedGradients.ZERO):
        self.assertIsNotNone(output[0])
      else:
        self.assertIsNone(output[0])


class FunctionGradientsTest(test_util.TensorFlowTestCase):

  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return framework_function.Defun(dtypes.float32, dtypes.float32, **
                                    kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")

    y = f(x, b)
    grads = gradients.gradients(y, [x, b])

    return self.evaluate(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
      # Build gradient graph (should add SymbolicGradient node for function).
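      # With f(x, b) = x * x + b, y = (x^2 + b1)^2 + b2; at x = 2, b1 = 1 the
      # expected gradients are dy/dx = 2 * (x^2 + b1) * 2x = 40 and
      # dy/db1 = 2 * (x^2 + b1) = 10.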
      grads = gradients.gradients(y, [x, b1])

      self.assertAllEqual([40.0], self.evaluate(grads)[0])
      self.assertAllEqual([10.0], self.evaluate(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      with self.assertRaisesRegex(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())

  def testGradientWrtCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Foo():
        y = math_ops.multiply(x, 2.0, name="y")
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()

      self.assertEqual(self.evaluate(f), 2.0)

  def testGradientOfCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
      y = math_ops.multiply(x, 2.0, name="y")

      @framework_function.Defun()
      def Foo():
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()

      self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedResourceVariable(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0, name="var")

      @function.defun()
      def Foo():
        y = math_ops.multiply(var, 2.0, name="y")
        g = gradients_impl.gradients(y, var)
        return g[0]

      f = Foo()

      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedNested(self):
    with ops.Graph().as_default():
      x1 = constant_op.constant(1.0, name="x1")
      x2 = constant_op.constant(2.0, name="x2")
      x3 = math_ops.multiply(x1, x2, name="x3")

      @function.defun()
      def Outer():
        outer1 = array_ops.identity(x1, name="outer1")

        @function.defun()
        def Inner():
          inner1 = array_ops.identity(outer1, name="inner1")
          inner2 = array_ops.identity(x2, name="inner2")
          inner3 = array_ops.identity(x3, name="inner3")
          return gradients_impl.gradients([inner1, inner2, inner3, x1],
                                          [x1, x2])

        return Inner()

      x1_grad, x2_grad = Outer()

      # 1.0 + None + 2.0 + 1.0 = 4.0
      self.assertEqual(self.evaluate(x1_grad), 4.0)
      # None + 1.0 + 1.0 + None = 2.0
      self.assertEqual(self.evaluate(x2_grad), 2.0)

  def testCapturedFromFunction(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
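      # Outer captures x and defines y = 2 * x; Inner differentiates
      # z = 3 * y with respect to the captured y, so the expected value is 3.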

      @function.defun()
      def Outer():
        y = math_ops.multiply(x, 2.0, name="y")

        @function.defun()
        def Inner():
          z = math_ops.multiply(y, 3.0, name="z")
          g = gradients_impl.gradients(z, y)
          return g[0]

        return Inner()

      z_grad = Outer()

      self.assertEqual(self.evaluate(z_grad), 3.0)

  def testCapturedEagerTensors(self):
    # Test that we can handle captured eager tensors unrelated to the gradient
    # computation (i.e. we need to ignore them).
    # TODO(skyewm): make it an error if you try to take the gradient wrt a
    # captured EagerTensor
    with context.eager_mode():
      c = constant_op.constant(2.0, name="c")

      @function.defun
      def Foo():
        x = constant_op.constant(10.0, name="x")
        y = math_ops.multiply(x, c, name="y")
        # Regression test for b/122564611.
        z = math_ops.multiply(c, y, name="z")
        g = gradients_impl.gradients(z, x)
        return g[0]

      self.assertEqual(Foo().numpy(), 4.0)


class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
      assert igrad is None


class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegex(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


class HessianVectorProductTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = self.evaluate(hess_v)
      self.assertAllClose(hess_v_value, hess_v_actual)


class HessianTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
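    # (Since the gradient of x^T A x is (A + A^T) x, differentiating once
    # more gives the constant Hessian A + A^T used as the reference below.)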
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.session():
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = self.evaluate(hess)
      self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.session():
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
      for hess_value, hess_actual in zip(hess_values, hessians_actual):
        self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.cached_session():
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.session():
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
      hess_value = np.bmat([
          [elem * np.ones((m, m)) for elem in vec]
          for vec in np.eye(m)
      ]).astype("float32")
      self.assertAllEqual((m, m, m, m), hess_actual.shape)
      self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.session():
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
      hess_value = np.bmat([
          [elem * np.ones((n, n)) for elem in vec]
          for vec in np.eye(m)
      ]).astype("float32")
      self.assertAllEqual((m, n, m, n), hess_actual.shape)
      self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensor(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensorList(self):
    with self.cached_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense, self.evaluate(packed_sparse))

  @test_util.run_v1_only("b/120545219")
  def testInt64Indices(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = indexed_slices.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")

    # Smaller than the threshold: no warning.
    c_sparse = indexed_slices.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = indexed_slices.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertIn(
        "with 100000000 elements. This may consume a large amount of memory.",
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = indexed_slices.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertIn(
        "of unknown shape. This may consume a large amount of memory.",
        str(w[0].message))


class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testRealOnly(self):
    x = constant_op.constant(7 + 3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegex(
        TypeError, r"Gradients of complex tensors .* must set grad_ys "
        r"\(y\.dtype = complex64\)"):
      gradients.gradients(y, x)


class ResourceCondTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    gamma = resource_variable_ops.ResourceVariable(
        np.random.random((3,)),
        dtype="float32", name="gamma")

    inputs = array_ops.ones(shape=(3,), dtype="float32")

    def TestFn():
      output = inputs + gamma
      return output

    training = array_ops.placeholder_with_default(True, shape=())
    output = control_flow_ops.cond(
        training, TestFn, lambda: inputs)

    loss = output

    grads = gradients.gradients(
        loss, [gamma])
    self.assertNotIn(None, grads)


class GetDependentVariablesTest(test_util.TensorFlowTestCase):

  def testNoVariables(self):
    with ops.Graph().as_default():
      func = lambda x: array_ops.identity(x) + 5.0
      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])

      # There are no variables.
      self.assertEqual(dependent_vars, [])

  def testVariablesOutside(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + var

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])

  def testVariableSamePrefix(self):
    with ops.Graph().as_default():
      var_name = "my_variable"
      v_z = variable_scope.get_variable(var_name, shape=())
      v_o = variable_scope.get_variable(var_name + "_ones", shape=())

      # Both variables are closed over. They should both be found.
      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(set(dependent_vars), set([v_o, v_z]))

  def testVariablesOutsideButDSeparated(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is d-separated by the inputs. It should not be found.
      input_t = array_ops.identity(var) * 5.0

      func = lambda x: array_ops.identity(x) + 5.0
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonDifferentiable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(x):
        # non-differentiable dependency on var.
        # the variable should not be found.
        y = array_ops.ones_like(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testGetVariableByName(self):
    with context.graph_mode():
      init = constant_op.constant(100.0)
      var = variable_scope.variable(init, name="a/replica_1")
      if isinstance(var, variables.RefVariable):
        var._variable = array_ops.identity(var, name="a")
      else:
        var._handle = array_ops.identity(var, name="a")
      var2 = custom_gradient.get_variable_by_name("a")
      self.assertEqual(var2.name, var.name)

  def testVariablesOutsideAndNonTrainable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))

      # Both variables are used in the function but only the trainable one
      # should be found.
      var_trainable = variables.Variable(init, shape=(5,))
      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)

      def _Func(x):
        del x
        return var_trainable + var_nontrainable

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var_trainable])

  def testVariablesOutsideAndCustomGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      @custom_gradient.custom_gradient
      def _MyOnesLike(x):
        """Dummy version of ones_like which defines a gradient."""

        output = array_ops.ones_like(x)

        def _Grad(dy):
          return array_ops.identity(dy)

        return output, _Grad

      def _Func(x):
        # non-differentiable operation with custom gradient.
        # The variable should be found.
        y = _MyOnesLike(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])


class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  def testCustomGradientTrivial(self):

    @custom_gradient.custom_gradient
    def MyIdentity(x):

      def Grad(dy):
        return [3 * dy]

      return x, Grad

    with ops.Graph().as_default():
      x = constant(3.)
      y = MyIdentity(MyIdentity(x))
      dy = gradients.gradients(y, x)[0]
      with session.Session():
        self.assertEqual(9., self.evaluate(dy))

  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def MyMultiply(x1, x2):
      result = x1 * x2

      def Grad(dy):
        # Switched the ordering here.
        return [dy * x1, dy * x2]

      return result, Grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      y = MyMultiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])

      self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientClass(self):

    class Model:

      @custom_gradient.custom_gradient
      def Multiply(self, x1, x2):
        result = x1 * x2
        grad = lambda dy: (dy * x1, dy * x2)
        return result, grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      m = Model()
      y = m.Multiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientErrors(self):

    @custom_gradient.custom_gradient
    def F(x):

      def Grad(_):
        raise RuntimeError("x")

      return x, Grad

    with ops.Graph().as_default():
      x = constant(1.0)
      y = F(x)
      with self.assertRaises(RuntimeError):
        gradients.gradients(y, x)

  def testCustomGradientWithVariables(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = F(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)

      self.evaluate(variables.global_variables_initializer())
      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
      self.assertEqual(12., dw)

  def testCustomGradientWithCapture(self):
    with ops.Graph().as_default():
      x = constant(3.)

      @framework_function.Defun(dtypes.float32)
      def F(y):

        @custom_gradient.custom_gradient
        def MyMultiply(x1, x2):
          result = x1 * x2

          def Grad(dy):
            # Switched the ordering here.
            return [dy * x1, dy * x2]

          return result, Grad

        res = MyMultiply(x, y)
        return gradients.gradients(res, [y])

      y = constant(5.)
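      # Because Grad switches the ordering, the slot for y receives
      # dy * x2 = 5. rather than the true partial dy * x1 = 3.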
      dy = F(y)
      self.assertAllEqual(5., self.evaluate(dy))

  def testCustomGradientWithVariablesNoFalsePositives(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((3, 3))]

      return out, Grad

    with ops.Graph().as_default():
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        a = array_ops.ones((2, 4))

        # Variables in these layers shouldn't be picked up by the decorator.
        b = core_layers.dense(a, 3, use_bias=False)
        c = core_layers.dense(b, 3, use_bias=False)
        x = core_layers.dense(b, 3, use_bias=False) + c

        # Only the variables used in F.
        y = F(x)

        all_vars = vs.global_variables()
        assert len(all_vars) == 4
      grads = gradients.gradients(y, [x] + all_vars)
      _, var_grads = grads[0], grads[1:]
      for g in grads:
        self.assertIsNotNone(g)

      self.evaluate(variables.global_variables_initializer())
      dw = self.evaluate(math_ops.reduce_sum(var_grads[-1]))
      self.assertEqual(9., dw)

  def testCustomGradientWithVariablesEager(self):
    with context.eager_mode():
      layer = core_layers.Dense(4, use_bias=False)

      @custom_gradient.custom_gradient
      def F(x):
        out = layer(x)

        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
          del out_grad
          self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
          return (array_ops.ones((3, 2)),
                  [array_ops.ones((2, 4))])

        return out, Grad

      x = array_ops.ones((3, 2)) + 2.
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = F(x)
      w, = layer.variables
      dx, dw = tape.gradient(y, [x, w])
      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())

  @test_util.run_v1_only("b/120545219")
  def testCustomGradientErrorsWithNonResourceVariables(self):

    def F(x, use_resource=False):
      with variable_scope.variable_scope("f", use_resource=use_resource):
        out = core_layers.dense(x, 4, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        del out_grad
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])

      return out, Grad

    @custom_gradient.custom_gradient
    def FResource(x):
      return F(x, use_resource=True)

    @custom_gradient.custom_gradient
    def FNonResource(x):
      return F(x, use_resource=False)

    x = array_ops.ones((3, 2)) + 2.

    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
    with variable_scope.variable_scope("vs1", use_resource=True):
      with self.assertRaisesWithPredicateMatch(TypeError,
                                               "must be `ResourceVariable`s"):
        FNonResource(x)

    # Wrapping scope has use_resource=False but inner scope sets to True.
    # Passes.
    with variable_scope.variable_scope("vs2", use_resource=False):
      FResource(x)

  @parameterized.parameters(True, False)
  def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
    with context.eager_mode():
      x_captured = variables.Variable(3.)  # Used by FuncMult
      @custom_gradient.custom_gradient
      def FuncMult(x):
        def ActualGrad(dy, variables):  # pylint: disable=redefined-outer-name
          self.assertLen(variables, 1)
          self.assertIs(variables[0], x_captured)
          x_captured_grad = 5. * x * dy
          return (4. * x_captured * dy, [x_captured_grad])
        # Define the returned GradMult, using varargs; "variables" is kwonlyarg
        if anonymous_varargs:
          def GradMult(dy, *, variables=None):  # pylint: disable=redefined-outer-name
            return ActualGrad(dy, variables)
        else:
          def GradMult(*dys, variables=None):  # pylint: disable=redefined-outer-name
            return ActualGrad(dys[0], variables)

        return x * x_captured, GradMult

      x = variables.Variable(6.)
      with backprop.GradientTape(persistent=True) as g:
        y = FuncMult(x)
      self.assertAllEqual(g.gradient(y, x), 4. * 3.)

  def testWithNumpyInputs(self):
    with context.eager_mode():

      @custom_gradient.custom_gradient
      def F(x):
        out = x

        def Grad(_):
          return (None, None)

        return out, Grad

      x = np.ones((3, 2), dtype=np.float32)
      # Smoke test to ensure numpy inputs are accepted
      F(x)

  @test_util.run_v1_only("b/120545219")
  def testRVGradientsDynamicCond(self):
    with self.cached_session():
      alpha = resource_variable_ops.ResourceVariable(
          np.random.random((1,)),
          dtype="float32")

      conditional = array_ops.placeholder_with_default(True, shape=())
      output = control_flow_ops.cond(
          conditional, lambda: alpha * 2, lambda: alpha * 3)

      g, = gradients_impl.gradients(output, alpha)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(g, [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])

  def testRecursiveCustomGradient(self):
    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    @custom_gradient.custom_gradient
    def DoubleF(x):
      out = F(x)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = DoubleF(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)

      self.evaluate(variables.global_variables_initializer())
      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
      self.assertEqual(12., dw)

  @parameterized.named_parameters(
      [(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""),  # pylint: disable=g-complex-comprehension
        x_struct, y_struct)
       for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
       for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
      ])
  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
    """Tests that custom_gradient can handle structured inputs/outputs."""
    def Zeros(x):
      return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
    def GetStruct(x):
      return nest.map_structure(lambda _: None, x)

    def MakeVjp(f, *x):
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(nest.flatten(x))
        y = f(*x)
      def Vjp(dy):
        return tape.gradient(y, x, output_gradients=dy)
      return y, Vjp

    @custom_gradient.custom_gradient
    def F(*x):
      self.assertEqual(x_struct, GetStruct(x))
      def Vjp(*dy):
        self.assertEqual(len(nest.flatten(y_struct)),
                         len(nest.flatten(dy)))
        return nest.flatten(Zeros(x_struct))
      return Zeros(y_struct), Vjp

    x, dy = Zeros([x_struct, y_struct])
    y, vjp = MakeVjp(F, *x)
    dx = vjp(dy)
    self.assertEqual(x_struct, GetStruct(dx))
    self.assertEqual(y_struct, GetStruct(y))


class TensorListGradientsTest(test_util.TensorFlowTestCase):

  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]

      self.assertEqual(self.evaluate(grad), 5.)


class VariablesGradientTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):

  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
    """Returns gradients of `test_fn` with respect to `vars_to_grad`."""

    test_fn_re = custom_gradient.recompute_grad(test_fn)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_fn_re(inputs, vars_to_grad)
      out = test_fn(inputs, vars_to_grad)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  def _grad(self, f, argnums=0):
    """Return a function which computes the gradient of `f`."""

    def F(*params):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = f(*params)
      return tape.gradient(
          outputs,
          params[argnums],
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO)

    return F

  def _test_gradients(self, f, inputs, order, delta=1e-3, rtol=1e-2, atol=1e-6):
    """Tests backward jacobians of `f`'s [0, `order`)-order gradients."""
    if order < 1:
      raise ValueError(
          "`order` should be a positive integer, got '{}'.".format(order))
    if order > 1:
      self._test_gradients(
          f=self._grad(f),
          inputs=inputs,
          order=order - 1,
          delta=delta,
          rtol=rtol,
          atol=atol)
    sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
        f, inputs, delta=delta)
    self.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)

  def testRecomputeGradWrapped(self):

    def f(x):  # pylint: disable=invalid-name
      return 2 * x

    g = custom_gradient.recompute_grad(f)
    self.assertIs(g.__wrapped__, f)

  def testRecomputeGradZeroSizeInput(self):

    def F(x):
      return 2 * x

    x = array_ops.constant(())
    grads_re = self._grad(custom_gradient.recompute_grad(F))(x)
    grads = self._grad(F)(x)
    self.assertAllClose(grads_re, grads)

    f_graph = function.defun(F, input_signature=[tensor_spec.TensorSpec(None)])
    grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x)
    grads = self._grad(f_graph)(x)
    self.assertAllClose(grads_re, grads)

  def testRecomputeGradDifferentDtypesInputs(self):

    def F(x1, x2):
      return 2 * x1, 2 * x2

    x1 = array_ops.constant(1, dtype=dtypes.int32)
    x2 = array_ops.constant(1., dtype=dtypes.float32)
    grads_re = self._grad(custom_gradient.recompute_grad(F))(x1, x2)
    grads = self._grad(F)(x1, x2)
    self.assertAllClose(grads_re, grads)

    f_graph = function.defun(
        F,
        input_signature=[
            tensor_spec.TensorSpec(None, dtype=dtypes.int32),
            tensor_spec.TensorSpec(None, dtype=dtypes.float32),
        ])
    grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x1, x2)
    grads = self._grad(f_graph)(x1, x2)
    self.assertAllClose(grads_re, grads)

  @test_util.run_v2_only
  def testCustomGradientRecomputeGradHigherOrder(self):

    @custom_gradient.recompute_grad
    def F(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    self._test_gradients(F, [constant_op.constant([1.])], order=3)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works for grads of function args."""

    def TestFn(inputs, input_vars):
      return inputs * input_vars

    def TestFnSeq(inputs, input_vars):
      return (inputs * input_vars, inputs * input_vars * 2.0)

    with variable_scope.variable_scope("test", use_resource=True):
      test_var = variable_scope.get_variable(
          name="test_var",
          shape=10,
          trainable=True,
      )
      self.evaluate(test_var.assign(np.ones([10])))
      test_input = constant(np.ones((10, 10), dtype=np.float32))

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_input)

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      # Regression test for wrapping sequence outputting functions.
      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_input)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

  @parameterized.parameters(set((True, context.executing_eagerly())))
  def testFnRecomputeWithScopeGradient(self, use_tape):
    """Checks that recompute_grad works with var scope and GradientTape."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)

      with test_util.AbstractGradientTape(
          use_tape=use_tape, persistent=True) as tape:
        out_re = test_fn_re(test_input_t)
        out = TestFn(test_input_t)

      self.evaluate(variables.global_variables_initializer())
      grads_re = tape.gradient(out_re, variables.trainable_variables())
      grads = tape.gradient(out, variables.trainable_variables())

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecomputeSameTensor(self):
    """Check recompute_grad when wrapped f called as f(x, x) - b/147369366."""

    def TestFnMul(x, y):
      return x * y

    def TestFnSingleVar(x, y):
      # pylint: disable=unused-argument
      return x

    with variable_scope.variable_scope("test", use_resource=True):
      x = array_ops.ones((10))

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul, x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar, x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)


class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that assign op is not differentiable
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as quadratic function of variable z and check that the
    # gradient of output wrt z is correct.
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(grads[0], 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that assign op is not differentiable
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as quadratic function of variable z and check that the
    # gradient of output wrt z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()