# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import warnings

from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradient_checker_v2
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import unconnected_gradients
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest
from tensorflow.python.util import nest


class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
      h = bias_add(xw, b, name="h")
      w_grad = gradients.gradients(h, w)[0]
      self.assertEqual("MatMul", w_grad.op.type)
      self.assertEqual(w_grad.op._original_op, xw.op)
      self.assertTrue(w_grad.op.get_attr("transpose_a"))
      self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
      gw = gradients.gradients(c, [w])[0]
      self.assertEqual("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
        wx = math_ops.matmul(w, x)
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertNotEqual(wx.op.colocation_groups(), gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertNotEqual(w.op.colocation_groups(), gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]

      # Make sure the placer doesn't complain.
      self.evaluate(gz_x)

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  @test_util.run_v1_only("b/120545219")
  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAccumulateN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAddN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodTree(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testNoGradientForStringOutputsWithOpNamespace(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.namespace_test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = ops.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.cached_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
        self.assertEqual(vdx, vdy)

  @test_util.run_v1_only("b/120545219")
  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testVariableDefaultGrad(self, dtype):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, dtype=dtype)
      var = variables.Variable(init)
      dummy_const = constant_op.constant(0.0)
      gradient = gradients.gradients(
          dummy_const,
          var,
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
      )[0]
      self.assertEqual(gradient.dtype, dtype)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.VariableV1(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testDependentYs(self):
    with self.cached_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
      g = gradients.gradients([y, y2], x)
      self.assertAllClose(17502.0, g[0])
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0])
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0])

  @test_util.run_v1_only("b/120545219")
  def testPartialDerivatives(self):
    with self.cached_session():
      x = constant_op.constant(1.)
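      # Note: with y = 2*x and z = x + y below, dz/dx is 3.0 when the path
      # through y is included, but 1.0 once x and y are both treated as
      # constants via stop_gradients (only the direct edges into z remain).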
      y = 2 * x
      z = x + y
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  @test_util.run_v1_only("b/120545219")
  def testStopGradients(self):
    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))

      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradients
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.cached_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
      self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")

      self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")

      self.assertEqual(3.0, self.evaluate(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegex(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")

  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
                            unconnected_gradients.UnconnectedGradients.NONE)
  def testUnconnectedOpWithMultipleOutputs(self, unconnected_gradients_val):
    with ops.Graph().as_default():
      #  a    b
      #  |    |
      # IdentityN
      #  |    |
      #  c    d
      #  |
      # Identity
      #  |
      #  e
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      c, d = array_ops.identity_n([a, b])
      e = array_ops.identity(c)
      # The aggregated grads for the IdentityN node would look like
      # [Tensor, None]. We expect this None to be converted to zeros.
      output = gradients.gradients(
          e, d, unconnected_gradients=unconnected_gradients_val)
      if (unconnected_gradients_val ==
          unconnected_gradients.UnconnectedGradients.ZERO):
        self.assertIsNotNone(output[0])
      else:
        self.assertIsNone(output[0])

  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
                            unconnected_gradients.UnconnectedGradients.NONE)
  def testUnconnectedOpWithMultipleOutputsStopGradient(
      self, unconnected_gradients_val):
    with ops.Graph().as_default():
      #  a    b
      #  |    |
      # IdentityN
      #  |    |
      #  c    d
      #  |    |
      # SG    |
      #  |    |
      #   \  /
      #    +
      #    e
      a = constant_op.constant(1.0)
      b = constant_op.constant(1.0)
      c, d = array_ops.identity_n([a, b])
      e = array_ops.stop_gradient(c) + d
      # The aggregated grads for the IdentityN node would look like
      # [None, Tensor]. We expect this None to be converted to zeros.
      output = gradients.gradients(
          e, c, unconnected_gradients=unconnected_gradients_val)
      if (unconnected_gradients_val ==
          unconnected_gradients.UnconnectedGradients.ZERO):
        self.assertIsNotNone(output[0])
      else:
        self.assertIsNone(output[0])


class FunctionGradientsTest(test_util.TensorFlowTestCase):

  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return framework_function.Defun(dtypes.float32, dtypes.float32, **
                                    kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")

    y = f(x, b)
    grads = gradients.gradients(y, [x, b])

    return self.evaluate(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
      # Build gradient graph (should add SymbolicGradient node for function).
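      # With f(x, b) = x*x + b, y = (x^2 + b1)^2 + b2, so at x=2, b1=1:
      # dy/dx = 4*x*(x^2 + b1) = 40 and dy/db1 = 2*(x^2 + b1) = 10.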
      grads = gradients.gradients(y, [x, b1])

      self.assertAllEqual([40.0], self.evaluate(grads)[0])
      self.assertAllEqual([10.0], self.evaluate(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      with self.assertRaisesRegex(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())

  def testGradientWrtCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Foo():
        y = math_ops.multiply(x, 2.0, name="y")
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()

      self.assertEqual(self.evaluate(f), 2.0)

  def testGradientOfCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
      y = math_ops.multiply(x, 2.0, name="y")

      @framework_function.Defun()
      def Foo():
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()

      self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedResourceVariable(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0, name="var")

      @function.defun()
      def Foo():
        y = math_ops.multiply(var, 2.0, name="y")
        g = gradients_impl.gradients(y, var)
        return g[0]

      f = Foo()

      self.evaluate(variables.global_variables_initializer())
      self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedNested(self):
    with ops.Graph().as_default():
      x1 = constant_op.constant(1.0, name="x1")
      x2 = constant_op.constant(2.0, name="x2")
      x3 = math_ops.multiply(x1, x2, name="x3")

      @function.defun()
      def Outer():
        outer1 = array_ops.identity(x1, name="outer1")

        @function.defun()
        def Inner():
          inner1 = array_ops.identity(outer1, name="inner1")
          inner2 = array_ops.identity(x2, name="inner2")
          inner3 = array_ops.identity(x3, name="inner3")
          return gradients_impl.gradients([inner1, inner2, inner3, x1],
                                          [x1, x2])

        return Inner()

      x1_grad, x2_grad = Outer()

      # 1.0 + None + 2.0 + 1.0 = 4.0
      self.assertEqual(self.evaluate(x1_grad), 4.0)
      # None + 1.0 + 1.0 + None = 2.0
      self.assertEqual(self.evaluate(x2_grad), 2.0)

  def testCapturedFromFunction(self):
    with ops.Graph().as_default():
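      # The inner gradient below is taken with respect to the captured y,
      # so dz/dy = 3.0 regardless of the value of x.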
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Outer():
        y = math_ops.multiply(x, 2.0, name="y")

        @function.defun()
        def Inner():
          z = math_ops.multiply(y, 3.0, name="z")
          g = gradients_impl.gradients(z, y)
          return g[0]

        return Inner()

      z_grad = Outer()

      self.assertEqual(self.evaluate(z_grad), 3.0)

  def testCapturedEagerTensors(self):
    # Test that we can handle captured eager tensors unrelated to the gradient
    # computation (i.e. we need to ignore them).
    # TODO(skyewm): make it an error if you try to take the gradient wrt a
    # captured EagerTensor
    with context.eager_mode():
      c = constant_op.constant(2.0, name="c")

      @function.defun
      def Foo():
        x = constant_op.constant(10.0, name="x")
        y = math_ops.multiply(x, c, name="y")
        # Regression test for b/122564611.
        z = math_ops.multiply(c, y, name="z")
        g = gradients_impl.gradients(z, x)
        return g[0]

      self.assertEqual(Foo().numpy(), 4.0)


class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
      assert igrad is None


class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegex(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


class HessianVectorProductTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = self.evaluate(hess_v)
      self.assertAllClose(hess_v_value, hess_v_actual)


class HessianTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
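    # (This follows from grad f(x) = (A + A^T) x, whose Jacobian is A + A^T,
    # independent of x.)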
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.session():
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = self.evaluate(hess)
      self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.session():
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
      for hess_value, hess_actual in zip(hess_values, hessians_actual):
        self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.cached_session():
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.session():
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
      hess_value = np.bmat([
          [elem*np.ones((m, m)) for elem in vec]
          for vec in np.eye(m)
      ]).astype("float32")
      self.assertAllEqual((m, m, m, m), hess_actual.shape)
      self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.session():
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
      hess_value = np.bmat([
          [elem*np.ones((n, n)) for elem in vec]
          for vec in np.eye(m)
      ]).astype("float32")
      self.assertAllEqual((m, n, m, n), hess_actual.shape)
      self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensor(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensorList(self):
    with self.cached_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense, self.evaluate(packed_sparse))

  @test_util.run_v1_only("b/120545219")
  def testInt64Indices(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = ops.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")

    # Smaller than the threshold: no warning.
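    # (A 4x4x4x4 dense shape is only 256 elements; the 100^4 = 100,000,000
    # element case below is what triggers the warning.)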
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertIn(
        "with 100000000 elements. This may consume a large amount of memory.",
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertIn(
        "of unknown shape. This may consume a large amount of memory.",
        str(w[0].message))


class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testRealOnly(self):
    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegex(
        TypeError, r"Gradients of complex tensors must set grad_ys "
        r"\(y\.dtype = tf\.complex64\)"):
      gradients.gradients(y, x)


class ResourceCondTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    gamma = resource_variable_ops.ResourceVariable(
        np.random.random((3,)),
        dtype="float32", name="gamma")

    inputs = array_ops.ones(shape=(3,), dtype="float32")

    def TestFn():
      output = inputs + gamma
      return output

    training = array_ops.placeholder_with_default(True, shape=())
    output = control_flow_ops.cond(
        training, TestFn, lambda: inputs)

    loss = output

    grads = gradients.gradients(
        loss, [gamma])
    self.assertNotIn(None, grads)


class GetDependentVariablesTest(test_util.TensorFlowTestCase):

  def testNoVariables(self):
    with ops.Graph().as_default():
      func = lambda x: array_ops.identity(x) + 5.0
      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])

      # There are no variables.
      self.assertEqual(dependent_vars, [])

  def testVariablesOutside(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + var

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])

  def testVariableSamePrefix(self):
    with ops.Graph().as_default():
      var_name = "my_variable"
      v_z = variable_scope.get_variable(var_name, shape=())
      v_o = variable_scope.get_variable(var_name + "_ones", shape=())

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(set(dependent_vars), set([v_o, v_z]))

  def testVariablesOutsideButDSeparated(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is d-separated by the inputs. It should not be found.
      input_t = array_ops.identity(var) * 5.0

      func = lambda x: array_ops.identity(x) + 5.0
      result_t = func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonDifferentiable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(x):
        # non-differentiable dependency on var.
        # the variable should not be found.
        y = array_ops.ones_like(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonTrainable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))

      # Both variables are used in the function but only the trainable one
      # should be found.
      var_trainable = variables.Variable(init, shape=(5,))
      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)

      def _Func(x):
        del x
        return var_trainable + var_nontrainable

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var_trainable])

  def testVariablesOutsideAndCustomGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      @custom_gradient.custom_gradient
      def _MyOnesLike(x):
        """Dummy version of ones_like which defines a gradient."""

        output = array_ops.ones_like(x)

        def _Grad(dy):
          return array_ops.identity(dy)

        return output, _Grad

      def _Func(x):
        # non-differentiable operation with custom gradient.
        # The variable should be found.
        y = _MyOnesLike(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient._get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])


class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  def testCustomGradientTrivial(self):

    @custom_gradient.custom_gradient
    def MyIdentity(x):

      def Grad(dy):
        return [3 * dy]

      return x, Grad

    with ops.Graph().as_default():
      x = constant(3.)
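      # Each MyIdentity scales the incoming gradient by 3, so composing two of
      # them yields dy/dx = 3 * 3 = 9.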
      y = MyIdentity(MyIdentity(x))
      dy = gradients.gradients(y, x)[0]
      with session.Session():
        self.assertEqual(9., self.evaluate(dy))

  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def MyMultiply(x1, x2):
      result = x1 * x2

      def Grad(dy):
        # Switched the ordering here.
        return [dy * x1, dy * x2]

      return result, Grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      y = MyMultiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])

      self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientClass(self):

    class Model(object):

      @custom_gradient.custom_gradient
      def Multiply(self, x1, x2):
        result = x1 * x2
        grad = lambda dy: (dy * x1, dy * x2)
        return result, grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      m = Model()
      y = m.Multiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientErrors(self):

    @custom_gradient.custom_gradient
    def F(x):

      def Grad(_):
        raise RuntimeError("x")

      return x, Grad

    with ops.Graph().as_default():
      x = constant(1.0)
      y = F(x)
      with self.assertRaises(RuntimeError):
        gradients.gradients(y, x)

  def testCustomGradientWithVariables(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = F(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)

      self.evaluate(variables.global_variables_initializer())
      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
      self.assertEqual(12., dw)

  def testCustomGradientWithCapture(self):
    with ops.Graph().as_default():
      x = constant(3.)

      @framework_function.Defun(dtypes.float32)
      def F(y):

        @custom_gradient.custom_gradient
        def MyMultiply(x1, x2):
          result = x1 * x2

          def Grad(dy):
            # Switched the ordering here.
            return [dy * x1, dy * x2]

          return result, Grad

        res = MyMultiply(x, y)
        return gradients.gradients(res, [y])

      y = constant(5.)
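      # Because Grad deliberately swaps the two gradient slots, the gradient
      # reported for y is dy * x2 = y (= 5.) rather than the true value x (= 3.).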
      dy = F(y)
      self.assertAllEqual(5., self.evaluate(dy))

  def testCustomGradientWithVariablesNoFalsePositives(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((3, 3))]

      return out, Grad

    with ops.Graph().as_default():
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        a = array_ops.ones((2, 4))

        # Variables in these layers shouldn't be picked up by the decorator.
        b = core_layers.dense(a, 3, use_bias=False)
        c = core_layers.dense(b, 3, use_bias=False)
        x = core_layers.dense(b, 3, use_bias=False) + c

        # Only the variables used in F.
        y = F(x)

        all_vars = vs.global_variables()
        assert len(all_vars) == 4
      grads = gradients.gradients(y, [x] + all_vars)
      _, var_grads = grads[0], grads[1:]
      for g in grads:
        self.assertIsNotNone(g)

      self.evaluate(variables.global_variables_initializer())
      dw = self.evaluate(math_ops.reduce_sum(var_grads[-1]))
      self.assertEqual(9., dw)

  def testCustomGradientWithVariablesEager(self):
    with context.eager_mode():
      layer = core_layers.Dense(4, use_bias=False)

      @custom_gradient.custom_gradient
      def F(x):
        out = layer(x)

        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
          del out_grad
          self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
          return (array_ops.ones((3, 2)),
                  [array_ops.ones((2, 4))])

        return out, Grad

      x = array_ops.ones((3, 2)) + 2.
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = F(x)
      w, = layer.variables
      dx, dw = tape.gradient(y, [x, w])
      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())

  @test_util.run_v1_only("b/120545219")
  def testCustomGradientErrorsWithNonResourceVariables(self):

    def F(x, use_resource=False):
      with variable_scope.variable_scope("f", use_resource=use_resource):
        out = core_layers.dense(x, 4, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        del out_grad
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])

      return out, Grad

    @custom_gradient.custom_gradient
    def FResource(x):
      return F(x, use_resource=True)

    @custom_gradient.custom_gradient
    def FNonResource(x):
      return F(x, use_resource=False)

    x = array_ops.ones((3, 2)) + 2.

    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
    with variable_scope.variable_scope("vs1", use_resource=True):
      with self.assertRaisesWithPredicateMatch(TypeError,
                                               "must be `ResourceVariable`s"):
        FNonResource(x)

    # Wrapping scope has use_resource=False but inner scope sets to True.
    # Passes.
    with variable_scope.variable_scope("vs2", use_resource=False):
      FResource(x)

  @parameterized.parameters(True, False)
  def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
    with context.eager_mode():
      x_captured = variables.Variable(3.)  # Used by FuncMult
      @custom_gradient.custom_gradient
      def FuncMult(x):
        def ActualGrad(dy, variables):  # pylint: disable=redefined-outer-name
          self.assertLen(variables, 1)
          self.assertIs(variables[0], x_captured)
          x_captured_grad = 5. * x * dy
          return (4. * x_captured * dy, [x_captured_grad])
        # Define the returned GradMult, using varargs; "variables" is kwonlyarg
        if anonymous_varargs:
          def GradMult(dy, *, variables=None):  # pylint: disable=redefined-outer-name
            return ActualGrad(dy, variables)
        else:
          def GradMult(*dys, variables=None):  # pylint: disable=redefined-outer-name
            return ActualGrad(dys[0], variables)

        return x * x_captured, GradMult

      x = variables.Variable(6.)
      with backprop.GradientTape(persistent=True) as g:
        y = FuncMult(x)
      self.assertAllEqual(g.gradient(y, x), 4. * 3.)

  def testWithNumpyInputs(self):
    with context.eager_mode():

      @custom_gradient.custom_gradient
      def F(x):
        out = x

        def Grad(_):
          return (None, None)

        return out, Grad

      x = np.ones((3, 2), dtype=np.float32)
      # Smoke test to ensure numpy inputs are accepted
      F(x)

  @test_util.run_v1_only("b/120545219")
  def testRVGradientsDynamicCond(self):
    with self.cached_session():
      alpha = resource_variable_ops.ResourceVariable(
          np.random.random((1,)),
          dtype="float32")

      conditional = array_ops.placeholder_with_default(True, shape=())
      output = control_flow_ops.cond(
          conditional, lambda: alpha * 2, lambda: alpha * 3)

      g, = gradients_impl.gradients(output, alpha)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(g, [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])

  def testRecursiveCustomGradient(self):
    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    @custom_gradient.custom_gradient
    def DoubleF(x):
      out = F(x)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad
    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = DoubleF(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)

      self.evaluate(variables.global_variables_initializer())
      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
      self.assertEqual(12., dw)

  @parameterized.named_parameters(
      [(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""),  # pylint: disable=g-complex-comprehension
        x_struct, y_struct)
       for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
       for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
      ])
  @test_util.run_in_graph_and_eager_modes
  def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
    """Tests that custom_gradient can handle structured inputs/outputs."""
    def Zeros(x):
      return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
    def GetStruct(x):
      return nest.map_structure(lambda _: None, x)

    def MakeVjp(f, *x):
      with backprop.GradientTape(persistent=True) as tape:
        tape.watch(nest.flatten(x))
        y = f(*x)
      def Vjp(dy):
        return tape.gradient(y, x, output_gradients=dy)
      return y, Vjp

    @custom_gradient.custom_gradient
    def F(*x):
      self.assertEqual(x_struct, GetStruct(x))
      def Vjp(*dy):
        self.assertEqual(len(nest.flatten(y_struct)),
                         len(nest.flatten(dy)))
        return nest.flatten(Zeros(x_struct))
      return Zeros(y_struct), Vjp

    x, dy = Zeros([x_struct, y_struct])
    y, vjp = MakeVjp(F, *x)
    dx = vjp(dy)
    self.assertEqual(x_struct, GetStruct(dx))
    self.assertEqual(y_struct, GetStruct(y))


class TensorListGradientsTest(test_util.TensorFlowTestCase):

  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]

      self.assertEqual(self.evaluate(grad), 5.)


class VariablesGradientTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):

  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
    """Returns gradients of `test_model` with respect to `vars_to_grad`."""

    test_fn_re = custom_gradient.recompute_grad(test_fn)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_fn_re(inputs, vars_to_grad)
      out = test_fn(inputs, vars_to_grad)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  def _grad(self, f, argnums=0):
    """Return a function which computes the gradient of `f`."""

    def F(*params):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = f(*params)
      return tape.gradient(
          outputs,
          params[argnums],
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO)

    return F

  def _test_gradients(self, f, inputs, order, delta=1e-3, rtol=1e-2, atol=1e-6):
    """Tests backward jacobians of `f`'s [0, `order`)-order gradients."""
    if order < 1:
      raise ValueError(
          "`order` should be a positive integer, got '{}'.".format(order))
    if order > 1:
      self._test_gradients(
          f=self._grad(f),
          inputs=inputs,
          order=order - 1,
          delta=delta,
          rtol=rtol,
          atol=atol)
    sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
        f, inputs, delta=delta)
    self.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)

  @test_util.run_v2_only
  def testCustomGradientRecomputeGradHigherOrder(self):

    @custom_gradient.recompute_grad
    def F(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    self._test_gradients(F, [constant_op.constant([1.])], order=3)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works with grads of function args."""

    def TestFn(inputs, input_vars):
      return inputs * input_vars

    def TestFnSeq(inputs, input_vars):
      return (inputs * input_vars, inputs * input_vars * 2.0)

    with variable_scope.variable_scope("test", use_resource=True):
      test_var = variable_scope.get_variable(
          name="test_var",
          shape=10,
          trainable=True,
      )
      self.evaluate(test_var.assign(np.ones([10])))
      test_input = constant(np.ones((10, 10), dtype=np.float32))

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_input)

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      # Regression test for wrapping sequence outputting functions.
      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_input)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

  @parameterized.parameters(set((True, context.executing_eagerly())))
  def testFnRecomputeWithScopeGradient(self, use_tape):
    """Checks that recompute_grad works with var scope and GradientTape."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)

      with test_util.AbstractGradientTape(
          use_tape=use_tape, persistent=True) as tape:
        out_re = test_fn_re(test_input_t)
        out = TestFn(test_input_t)

    self.evaluate(variables.global_variables_initializer())
    grads_re = tape.gradient(out_re, variables.trainable_variables())
    grads = tape.gradient(out, variables.trainable_variables())

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecomputeSameTensor(self):
    """Check recompute_grad when wrapped f called as f(x, x) - b/147369366."""

    def TestFnMul(x, y):
      return x * y

    def TestFnSingleVar(x, y):
      # pylint: disable=unused-argument
      return x

    with variable_scope.variable_scope("test", use_resource=True):
      x = array_ops.ones((10))

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)


class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that assign op is not differentiable
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as quadratic function of variable z and check that the
    # gradient of output wrt to z is correct.
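    # (The wrapped assign receives z**2 as its input, so the forwarded
    # gradient is d(z**2)/dz = 2*z = 6 at z = 3.)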
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(grads[0], 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that assign op is not differentiable
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as quadratic function of variable z and check that the
    # gradient of output wrt to z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()