# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import warnings

from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.keras.engine import training
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import unconnected_gradients
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest


class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
      h = bias_add(xw, b, name="h")
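      # The gradient of h with respect to w should itself be a MatMul op
      # derived from the forward xw op, with its first input transposed.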
      w_grad = gradients.gradients(h, w)[0]
      self.assertEquals("MatMul", w_grad.op.type)
      self.assertEquals(w_grad.op._original_op, xw.op)
      self.assertTrue(w_grad.op.get_attr("transpose_a"))
      self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
      gw = gradients.gradients(c, [w])[0]
      self.assertEquals("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
        wx = math_ops.matmul(w, x)
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
        x = constant(1.0, shape=[1, 2])
        y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]
      with session.Session():
        # Make sure the placer doesn't complain.
        self.evaluate(gz_x)

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  @test_util.run_v1_only("b/120545219")
  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
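    # Here z = 3 * y = 6 * x, so dz/dx should evaluate to 6.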
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAccumulateN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAddN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodTree(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEquals(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testNoGradientForStringOutputsWithOpNamespace(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.namespace_test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = ops.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.cached_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
      self.assertEqual(vdx, vdy)

  @test_util.run_v1_only("b/120545219")
  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testVariableDefaultGrad(self, dtype):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, dtype=dtype)
      var = variables.Variable(init)
      dummy_const = constant_op.constant(0.0)
      gradient = gradients.gradients(
          dummy_const,
          var,
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
      )[0]
      self.assertEqual(gradient.dtype, dtype)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.VariableV1(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testDependentYs(self):
    with self.cached_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
      g = gradients.gradients([y, y2], x)
      self.assertAllClose(17502.0, g[0].eval())
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0].eval())
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testPartialDerivatives(self):
    with self.cached_session():
      x = constant_op.constant(1.)
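      # Below, z = x + y with y = 2 * x, so dz/dx is 3 when y is expanded in
      # terms of x, but only 1 when both x and y are listed as stop points.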
      y = 2 * x
      z = x + y
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  @test_util.run_v1_only("b/120545219")
  def testStopGradients(self):
    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))

      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradients
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.cached_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
      self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertEquals(3.0, self.evaluate(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegexp(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")


class FunctionGradientsTest(test_util.TensorFlowTestCase):

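  # The base function wrapped with Defun below is f(x, b) = x * x + b, whose
  # analytic gradients are df/dx = 2 * x and df/db = 1.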
  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return framework_function.Defun(dtypes.float32, dtypes.float32, **
                                    kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")

    y = f(x, b)
    grads = gradients.gradients(y, [x, b])
    with self.cached_session() as sess:
      return sess.run(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b1])

      with self.cached_session() as sess:
        self.assertAllEqual([40.0], self.evaluate(grads)[0])
        self.assertAllEqual([10.0], self.evaluate(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())

  def testGradientWrtCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Foo():
        y = math_ops.multiply(x, 2.0, name="y")
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testGradientOfCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
      y = math_ops.multiply(x, 2.0, name="y")

      @framework_function.Defun()
      def Foo():
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedResourceVariable(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0, name="var")

      @function.defun()
      def Foo():
        y = math_ops.multiply(var, 2.0, name="y")
        g = gradients_impl.gradients(y, var)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedNested(self):
    with ops.Graph().as_default():
      x1 = constant_op.constant(1.0, name="x1")
      x2 = constant_op.constant(2.0, name="x2")
      x3 = math_ops.multiply(x1, x2, name="x3")

      @function.defun()
      def Outer():
        outer1 = array_ops.identity(x1, name="outer1")

        @function.defun()
        def Inner():
          inner1 = array_ops.identity(outer1, name="inner1")
          inner2 = array_ops.identity(x2, name="inner2")
          inner3 = array_ops.identity(x3, name="inner3")
          return gradients_impl.gradients([inner1, inner2, inner3, x1],
                                          [x1, x2])

        return Inner()

      x1_grad, x2_grad = Outer()
      with self.cached_session() as sess:
        # 1.0 + None + 2.0 + 1.0 = 4.0
        self.assertEqual(self.evaluate(x1_grad), 4.0)
        # None + 1.0 + 1.0 + None = 2.0
        self.assertEqual(self.evaluate(x2_grad), 2.0)

  def testCapturedFromFunction(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Outer():
        y = math_ops.multiply(x, 2.0, name="y")

        @function.defun()
        def Inner():
          z = math_ops.multiply(y, 3.0, name="z")
          g = gradients_impl.gradients(z, y)
          return g[0]

        return Inner()

      z_grad = Outer()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(z_grad), 3.0)

  def testCapturedEagerTensors(self):
    # Test that we can handle captured eager tensors unrelated to the gradient
    # computation (i.e. we need to ignore them).
    # TODO(skyewm): make it an error if you try to take the gradient wrt a
    # captured EagerTensor
    with context.eager_mode():
      c = constant_op.constant(2.0, name="c")

      @function.defun
      def Foo():
        x = constant_op.constant(10.0, name="x")
        y = math_ops.multiply(x, c, name="y")
        # Regression test for b/122564611.
        z = math_ops.multiply(c, y, name="z")
        g = gradients_impl.gradients(z, x)
        return g[0]

      self.assertEqual(Foo().numpy(), 4.0)


class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
      assert igrad is None


class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


class HessianVectorProductTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = self.evaluate(hess_v)
      self.assertAllClose(hess_v_value, hess_v_actual)


class HessianTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.session(use_gpu=True):
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = self.evaluate(hess)
    self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.session(use_gpu=True):
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
    for hess_value, hess_actual in zip(hess_values, hessians_actual):
      self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.cached_session(use_gpu=True):
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((m, m)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, m, m, m), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((n, n)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, n, m, n), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensor(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensorList(self):
    with self.cached_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse))

  @test_util.run_v1_only("b/120545219")
  def testInt64Indices(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = ops.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")

    # Smaller than the threshold: no warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "with 100000000 elements. This may consume a large amount of memory." in
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "of unknown shape. This may consume a large amount of memory." in
        str(w[0].message))


class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testRealOnly(self):
    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegexp(
        TypeError,
        r"Gradients of complex tensors must set grad_ys "
        r"\(y\.dtype = tf\.complex64\)"):
      gradients.gradients(y, x)


class ResourceCondTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    gamma = resource_variable_ops.ResourceVariable(
        np.random.random((3,)),
        dtype="float32", name="gamma")

    inputs = array_ops.ones(shape=(3,), dtype="float32")

    def TestFn():
      output = inputs + gamma
      return output

    training = array_ops.placeholder_with_default(True, shape=())
    output = control_flow_ops.cond(
        training, TestFn, lambda: inputs)

    loss = output

    grads = gradients.gradients(
        loss, [gamma])
    self.assertTrue(None not in grads)


class GetDependentVariablesTest(test_util.TensorFlowTestCase):

  def testNoVariables(self):
    with ops.Graph().as_default():
      func = lambda x: array_ops.identity(x) + 5.0
      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])

      # There are no variables.
      self.assertEqual(dependent_vars, [])

  def testVariablesOutside(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + var

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])

  def testVariableSamePrefix(self):
    with ops.Graph().as_default():
      var_name = "my_variable"
      v_z = variable_scope.get_variable(var_name, shape=())
      v_o = variable_scope.get_variable(var_name + "_ones", shape=())

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(set(dependent_vars), set([v_o, v_z]))

  def testVariablesOutsideButDSeparated(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is d-separated by the inputs. It should not be found.
      input_t = array_ops.identity(var) * 5.0

      func = lambda x: array_ops.identity(x) + 5.0
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonDifferentiable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(x):
        # Non-differentiable dependency on var.
        # The variable should not be found.
        y = array_ops.ones_like(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonTrainable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))

      # Both variables are used in the function but only the trainable one
      # should be found.
      var_trainable = variables.Variable(init, shape=(5,))
      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)

      def _Func(x):
        del x
        return var_trainable + var_nontrainable

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var_trainable])

  def testNesting(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(inputs):
        x = inputs["x"]
        result = array_ops.identity(x) + 5.0 + var
        return {
            "y": result
        }

      input_t = constant_op.constant(2.0)
      func_inputs = {
          "x": input_t
      }
      result_t = _Func(func_inputs)

      # Ensure we can deal with dictionary input and output.
      dependent_vars = custom_gradient.get_dependent_variables(
          func_inputs, result_t)
      self.assertEqual(dependent_vars, [var])

  def testVariablesOutsideAndCustomGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      @custom_gradient.custom_gradient
      def _MyOnesLike(x):
        """Dummy version of ones_like which defines a gradient."""

        output = array_ops.ones_like(x)

        def _Grad(dy):
          return array_ops.identity(dy)

        return output, _Grad

      def _Func(x):
        # Non-differentiable operation with custom gradient.
        # The variable should be found.
        y = _MyOnesLike(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])


class CustomGradientTest(test_util.TensorFlowTestCase):

  def testCustomGradientTrivial(self):

    @custom_gradient.custom_gradient
    def MyIdentity(x):

      def Grad(dy):
        return [3 * dy]

      return x, Grad

    with ops.Graph().as_default():
      x = constant(3.)
      y = MyIdentity(MyIdentity(x))
      dy = gradients.gradients(y, x)[0]
      with session.Session():
        self.assertEqual(9., self.evaluate(dy))

  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def MyMultiply(x1, x2):
      result = x1 * x2

      def Grad(dy):
        # Switched the ordering here.
        return [dy * x1, dy * x2]

      return result, Grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      y = MyMultiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      with session.Session() as sess:
        self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientClass(self):

    class Model(object):

      @custom_gradient.custom_gradient
      def Multiply(self, x1, x2):
        result = x1 * x2
        grad = lambda dy: (dy * x1, dy * x2)
        return result, grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      m = Model()
      y = m.Multiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientErrors(self):

    @custom_gradient.custom_gradient
    def F(x):

      def Grad(_):
        raise RuntimeError("x")

      return x, Grad

    with ops.Graph().as_default():
      x = constant(1.0)
      y = F(x)
      with self.assertRaises(RuntimeError):
        gradients.gradients(y, x)

  def testCustomGradientWithVariables(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = F(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertTrue(g is not None)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)

  def testCustomGradientWithVariablesNoFalsePositives(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((3, 3))]

      return out, Grad

    with ops.Graph().as_default():
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        a = array_ops.ones((2, 4))

        # Variables in these layers shouldn't be picked up by the decorator.
        b = core_layers.dense(a, 3, use_bias=False)
        c = core_layers.dense(b, 3, use_bias=False)
        x = core_layers.dense(b, 3, use_bias=False) + c

        # Only the variables used in F.
        y = F(x)

        all_vars = vs.global_variables()
        assert len(all_vars) == 4
        grads = gradients.gradients(y, [x] + all_vars)
        _, var_grads = grads[0], grads[1:]
        for g in grads:
          self.assertIsNotNone(g)
        with session.Session() as sess:
          self.evaluate(variables.global_variables_initializer())
          dw = sess.run(math_ops.reduce_sum(var_grads[-1]))
          self.assertEqual(9., dw)

  def testCustomGradientWithVariablesEager(self):
    with context.eager_mode():
      layer = core_layers.Dense(4, use_bias=False)

      @custom_gradient.custom_gradient
      def F(x):
        out = layer(x)

        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
          del out_grad
          self.assertEqual(1, len(variables))
          return (array_ops.ones((3, 2)),
                  [array_ops.ones((2, 4))])

        return out, Grad

      x = array_ops.ones((3, 2)) + 2.
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = F(x)
      w, = layer.variables
      dx, dw = tape.gradient(y, [x, w])
      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())

  @test_util.run_v1_only("b/120545219")
  def testCustomGradientErrorsWithNonResourceVariables(self):

    def F(x, use_resource=False):
      with variable_scope.variable_scope("f", use_resource=use_resource):
        out = core_layers.dense(x, 4, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        del out_grad
        self.assertEqual(1, len(variables))
        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])

      return out, Grad

    @custom_gradient.custom_gradient
    def FResource(x):
      return F(x, use_resource=True)

    @custom_gradient.custom_gradient
    def FNonResource(x):
      return F(x, use_resource=False)

    x = array_ops.ones((3, 2)) + 2.

    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
    with variable_scope.variable_scope("vs1", use_resource=True):
      with self.assertRaisesWithPredicateMatch(TypeError,
                                               "must be `ResourceVariable`s"):
        FNonResource(x)

    # Wrapping scope has use_resource=False but inner scope sets to True.
    # Passes.
    with variable_scope.variable_scope("vs2", use_resource=False):
      FResource(x)

  def testWithNumpyInputs(self):
    with context.eager_mode():

      @custom_gradient.custom_gradient
      def F(x):
        out = x

        def Grad(_):
          return (None, None)

        return out, Grad

      x = np.ones((3, 2), dtype=np.float32)
      # Smoke test to ensure numpy inputs are accepted
      F(x)

  @test_util.run_v1_only("b/120545219")
  def testRVGradientsDynamicCond(self):
    with self.cached_session():
      alpha = resource_variable_ops.ResourceVariable(
          np.random.random((1,)),
          dtype="float32")

      conditional = array_ops.placeholder_with_default(True, shape=())
      output = control_flow_ops.cond(
          conditional, lambda: alpha * 2, lambda: alpha * 3)

      g, = gradients_impl.gradients(output, alpha)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(g.eval(), [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])

  def testRecursiveCustomGradient(self):
    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    @custom_gradient.custom_gradient
    def DoubleF(x):
      out = F(x)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad
    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = DoubleF(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)


class TensorListGradientsTest(test_util.TensorFlowTestCase):

  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
      with self.cached_session() as sess:
        self.assertEquals(self.evaluate(grad), 5.)


class TestKerasModelClass(training.Model):
  """A simple tensorflow keras Model class definition."""

  def __init__(self, width):
    super(TestKerasModelClass, self).__init__()

    self.weight = variable_scope.get_variable(
        name="test_keras_var",
        shape=width,
        dtype=dtypes.float32,
        trainable=True,
        use_resource=True,
    )

  def call(self, inputs):
    return self.weight * inputs


class VariablesGradientTest(test_util.TensorFlowTestCase):

  def _TestVariablesGradient(self, inputs, test_model, vars_to_grad):
    """Returns gradients of `test_model` with respect to `vars_to_grad`."""

    test_model_re = custom_gradient.recompute_grad(test_model)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_model_re(inputs)
      out = test_model(inputs)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
    """Returns gradients of `test_fn` with respect to `vars_to_grad`."""

    test_fn_re = custom_gradient.recompute_grad(test_fn)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_fn_re(inputs, vars_to_grad)
      out = test_fn(inputs, vars_to_grad)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  @test_util.run_in_graph_and_eager_modes
  def testKerasRecompute(self):
    """Checks that recompute_grad works for a simple Keras Model."""

    test_model = TestKerasModelClass(10)
    test_input = constant(np.zeros((10, 10), dtype=np.float32))
    self.evaluate(variables.global_variables_initializer())
    test_model(test_input)  # Ensures keras model is initialized.
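    # Gradients w.r.t. the inputs and the model variables should match
    # between the recomputed model and the original model.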
    grads_re, grads = self._TestVariablesGradient(test_input, test_model,
                                                  test_input)

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

    grads_re, grads = self._TestVariablesGradient(test_input, test_model,
                                                  test_model.variables)

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works for grads of function args."""

    def TestFn(inputs, input_vars):
      return inputs * input_vars

    def TestFnSeq(inputs, input_vars):
      return (inputs * input_vars, inputs * input_vars * 2.0)

    with variable_scope.variable_scope("test", use_resource=True):
      test_var = variable_scope.get_variable(
          name="test_var",
          shape=10,
          trainable=True,
      )

      test_input = constant(np.zeros((10, 10), dtype=np.float32))

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_input)

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      # Regression test for wrapping sequence outputting functions.
      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_input)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

  @test_util.deprecated_graph_mode_only
  def testFnRecomputeWithScopeGradientTape(self):
    """Checks that recompute_grad works with var scope and GradientTape."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)

      with backprop.GradientTape(persistent=True) as tape:
        out_re = test_fn_re(test_input_t)
        out = TestFn(test_input_t)

      grads_re = tape.gradient(out_re, variables.trainable_variables())
      grads = tape.gradient(out, variables.trainable_variables())

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)
        self.assertAllClose(g, g_re)

  @test_util.deprecated_graph_mode_only
  def testFnRecomputeWithScopeGradients(self):
    """Checks that recompute_grad works with var scope and gradients(..)."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)
      out_re = test_fn_re(test_input_t)
      out = TestFn(test_input_t)

    grads_re = gradients.gradients(out_re, variables.trainable_variables())
    grads = gradients.gradients(out, variables.trainable_variables())

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecomputeSameTensor(self):
    """Check recompute_grad when wrapped f called as f(x, x) - b/147369366."""

    def TestFnMul(x, y):
      return x * y

    def TestFnSingleVar(x, y):
      # pylint: disable=unused-argument
      return x

    with variable_scope.variable_scope("test", use_resource=True):
      x = array_ops.ones((10))

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)


class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that assign op is not differentiable
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as a quadratic function of variable z and check that the
    # gradient of the output with respect to z is correct.
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)
    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      self.assertAllClose(grads[0].eval(), 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that assign op is not differentiable
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as a quadratic function of variable z and check that the
    # gradient of the output with respect to z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()