1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.gradients."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import sys
21import warnings
22
23from absl.testing import parameterized
24import numpy as np
25from tensorflow.python.client import session
26from tensorflow.python.eager import backprop
27from tensorflow.python.eager import context
28from tensorflow.python.eager import function
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import function as framework_function
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import test_ops
34from tensorflow.python.framework import test_util
35from tensorflow.python.framework.constant_op import constant
36from tensorflow.python.keras.engine import training
37from tensorflow.python.layers import core as core_layers
38from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
39from tensorflow.python.ops import array_ops
40from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
41from tensorflow.python.ops import control_flow_ops
42from tensorflow.python.ops import custom_gradient
43from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
44from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
45from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
46from tensorflow.python.ops import gradients
47from tensorflow.python.ops import gradients_impl
48from tensorflow.python.ops import init_ops
49from tensorflow.python.ops import list_ops
50from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
51from tensorflow.python.ops import math_ops
52from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
53from tensorflow.python.ops import resource_variable_ops
54from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
55from tensorflow.python.ops import state_ops
56from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
57from tensorflow.python.ops import tensor_array_ops
58from tensorflow.python.ops import unconnected_gradients
59from tensorflow.python.ops import variable_scope
60from tensorflow.python.ops import variables
61from tensorflow.python.ops.nn_ops import bias_add
62from tensorflow.python.platform import googletest
63
64
65class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
66
67  def testGradients(self):
68    with ops.Graph().as_default():
69      inp = constant(1.0, shape=[32, 100], name="in")
70      w = constant(1.0, shape=[100, 10], name="w")
71      b = constant(1.0, shape=[10], name="b")
72      xw = math_ops.matmul(inp, w, name="xw")
73      h = bias_add(xw, b, name="h")
74      w_grad = gradients.gradients(h, w)[0]
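    # For xw = matmul(x, w), the gradient w.r.t. w is x^T @ grad_xw, so the
    # backward op is a MatMul with transpose_a=True and transpose_b=False.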
75    self.assertEqual("MatMul", w_grad.op.type)
76    self.assertEqual(w_grad.op._original_op, xw.op)
77    self.assertTrue(w_grad.op.get_attr("transpose_a"))
78    self.assertFalse(w_grad.op.get_attr("transpose_b"))
79
80  def testUnusedOutput(self):
81    with ops.Graph().as_default():
82      w = constant(1.0, shape=[2, 2])
83      x = constant(1.0, shape=[2, 2])
84      wx = math_ops.matmul(w, x)
85      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
86      c = math_ops.reduce_sum(split_wx[1])
87      gw = gradients.gradients(c, [w])[0]
88    self.assertEqual("MatMul", gw.op.type)
89
90  def testColocateGradients(self):
91    with ops.Graph().as_default() as g:
92      w = constant(1.0, shape=[1, 1])
93      x = constant(1.0, shape=[1, 2])
94      with g.device("/device:GPU:0"):
95        wx = math_ops.matmul(w, x)
96      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
97    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
98
99  def testColocateGradientsWithAggregation(self):
100    with ops.Graph().as_default() as g:
101      with g.device("/device:GPU:1"):
102        w = constant(1.0, shape=[1, 1])
103      x = constant(1.0, shape=[1, 2])
104      y = constant(1.0, shape=[1, 2])
105      wx = math_ops.matmul(w, x)
106      wy = math_ops.matmul(w, y)
107      with g.device("/device:GPU:0"):
108        z = wx + wy
109
110      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
111      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
112
113      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
114      self.assertNotEqual(wx.op.colocation_groups(), gw2.op.colocation_groups())
115
116  def testColocateGradientsWithAggregationInMultipleDevices(self):
117    with ops.Graph().as_default() as g:
118      with g.device("/device:GPU:1"):
119        w = constant(1.0, shape=[1, 1])
120      x = constant(1.0, shape=[1, 2])
121      y = constant(1.0, shape=[1, 2])
122      with g.device("/task:1"):
123        wx = math_ops.matmul(w, x)
124      with g.device("/task:2"):
125        wy = math_ops.matmul(w, y)
126      with g.device("/device:GPU:0"):
127        z = wx + wy
128
129      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
130      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
131
132      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
133      self.assertNotEqual(w.op.colocation_groups(), gw2.op.colocation_groups())
134
135  def testColocateGradientsWithGateGradients(self):
136    if not test_util.is_gpu_available():
137      self.skipTest("No GPU available")
138    with ops.Graph().as_default() as g:
139      with g.device("/device:CPU:0"):
140        x = constant(1.0, shape=[1, 1])
141        y = constant(1.0, shape=[1, 1])
142        s = x + y
143      with g.device("/device:GPU:0"):
144        z = math_ops.reduce_sum(s)
145
146      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
147                                 gate_gradients=True)[0]
148      with session.Session():
149        # Make sure the placer doesn't complain.
150        self.evaluate(gz_x)
151
152  def testBoundaryStop(self):
153    # Test that we don't differentiate 'x'. The gradient function for 'x' is
154    # set explicitly to None so we will get an exception if the gradient code
155    # tries to differentiate 'x'.
156    with ops.Graph().as_default():
157      c = constant(1.0)
158      x = array_ops.identity(c)
159      y = x + 1.0
160      z = y + 1
161      grads = gradients.gradients(z, [x])
162      self.assertTrue(all(x is not None for x in grads))
163
164  @test_util.run_v1_only("b/120545219")
165  def testBoundaryContinue(self):
166    # Test that we differentiate both 'x' and 'y' correctly when x is a
167    # predecessor of y.
168    with self.cached_session():
169      x = constant(1.0)
170      y = x * 2.0
171      z = y * 3.0
172      grads = gradients.gradients(z, [x, y])
173      self.assertTrue(all(x is not None for x in grads))
174      self.assertEqual(6.0, grads[0].eval())
175
176  @test_util.run_v1_only("b/120545219")
177  def testAggregationMethodAccumulateN(self):
178    with self.cached_session():
179      x = constant(1.0)
180      y = x * 2.0
181      z = y + y + y + y + y + y + y + y + y + y
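      # z is the sum of ten copies of y, so dz/dy = 10 and
      # dz/dx = 10 * dy/dx = 20.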
182      grads = gradients.gradients(
183          z, [x, y],
184          aggregation_method=gradients.AggregationMethod.
185          EXPERIMENTAL_ACCUMULATE_N)
186      self.assertTrue(all(x is not None for x in grads))
187      self.assertEqual(20.0, grads[0].eval())
188      self.assertEqual(10.0, grads[1].eval())
189
190  @test_util.run_v1_only("b/120545219")
191  def testAggregationMethodAddN(self):
192    with self.cached_session():
193      x = constant(1.0)
194      y = x * 2.0
195      z = y + y + y + y + y + y + y + y + y + y
196      grads = gradients.gradients(
197          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
198      self.assertTrue(all(x is not None for x in grads))
199      self.assertEqual(20.0, grads[0].eval())
200      self.assertEqual(10.0, grads[1].eval())
201
202  @test_util.run_v1_only("b/120545219")
203  def testAggregationMethodTree(self):
204    with self.cached_session():
205      x = constant(1.0)
206      y = x * 2.0
207      z = y + y + y + y + y + y + y + y + y + y
208      grads = gradients.gradients(
209          z, [x, y],
210          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
211      self.assertTrue(all(x is not None for x in grads))
212      self.assertEqual(20.0, grads[0].eval())
213      self.assertEqual(10.0, grads[1].eval())
214
215  def testNoGradientForStringOutputs(self):
216    with ops.Graph().as_default():
217
218      def _TestOpGrad(_, float_grad, string_grad):
219        """Gradient function for TestStringOutput."""
220        self.assertEqual(float_grad.dtype, dtypes.float32)
221        self.assertFalse(string_grad)
222        return float_grad
223
224      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
225
226      c = constant(1.0)
227      x, _ = test_ops.test_string_output(c)
228      z = x * 2.0
229      w = z * 3.0
230      grads = gradients.gradients(z, [c])
231      self.assertIsInstance(grads[0], ops.Tensor)
232      grads = gradients.gradients(w, [c])
233      self.assertIsInstance(grads[0], ops.Tensor)
234
235  def testNoGradientForStringOutputsWithOpNamespace(self):
236    with ops.Graph().as_default():
237
238      def _TestOpGrad(_, float_grad, string_grad):
239        """Gradient function for TestStringOutput."""
240        self.assertEqual(float_grad.dtype, dtypes.float32)
241        self.assertFalse(string_grad)
242        return float_grad
243
244      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)
245
246      c = constant(1.0)
247      x, _ = test_ops.namespace_test_string_output(c)
248      z = x * 2.0
249      w = z * 3.0
250      grads = gradients.gradients(z, [c])
251      self.assertIsInstance(grads[0], ops.Tensor)
252      grads = gradients.gradients(w, [c])
253      self.assertIsInstance(grads[0], ops.Tensor)
254
255  def testSingletonIndexedSlices(self):
256    with ops.Graph().as_default():
257      x = array_ops.placeholder(dtypes.float32)
258      y = array_ops.identity(x)
259      dy = ops.IndexedSlices(
260          array_ops.placeholder(dtypes.float32),
261          array_ops.placeholder(dtypes.int32))
262      dx, = gradients.gradients(y, x, grad_ys=dy)
263      # The IndexedSlices gradient of tf.identity is the identity map.
264      with self.cached_session() as sess:
265        vdx, vdy = sess.run(
266            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
267      self.assertEqual(vdx, vdy)
268
269  @test_util.run_v1_only("b/120545219")
270  def testNonDifferentiableSwitchInWhileLoop(self):
271    with ops.Graph().as_default():
272      v = array_ops.placeholder(dtypes.float32, [])
273
274      def _Step(i, a, ta):
275        a += math_ops.cast(v, dtypes.int32)
276        return (i + 1, a, ta.write(i, a))
277
278      n = 4
279      i, _, ta = control_flow_ops.while_loop(
280          lambda i, *_: i < n,
281          _Step, [0, 0, tensor_array_ops.TensorArray(
282              dtypes.int32, size=n)])
283      target = ta.read(i - 1)
284      grad, = gradients.gradients(target, v)
285      self.assertIsNone(grad)
286
287  def testVariableReadValueGradient(self):
288    with ops.Graph().as_default():
289      init = constant_op.constant(100.0)
290      var = variables.Variable(init)
291      gradient = gradients.gradients(var.read_value(), var)
292      self.assertIsNotNone(gradient)
293
294  @parameterized.parameters(dtypes.float32, dtypes.float64)
295  def testVariableDefaultGrad(self, dtype):
296    with ops.Graph().as_default():
297      init = constant_op.constant(100.0, dtype=dtype)
298      var = variables.Variable(init)
299      dummy_const = constant_op.constant(0.0)
300      gradient = gradients.gradients(
301          dummy_const,
302          var,
303          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
304      )[0]
305      self.assertEqual(gradient.dtype, dtype)
306      self.assertIsNotNone(gradient)
307
308  def testVariableAsGraphElementGradient(self):
309    with ops.Graph().as_default() as graph:
310      init = constant_op.constant(100.0)
311      var = variables.Variable(init)
312      gradient = gradients.gradients(graph.as_graph_element(var), var)
313      self.assertIsNotNone(gradient)
314
315  @test_util.run_v1_only("b/120545219")
316  def testVariableRefGradient(self):
317    with ops.Graph().as_default():
318      init = constant_op.constant(100.0)
319      var = variables.VariableV1(init)
320      gradient = gradients.gradients(var._ref(), var)
321      self.assertIsNotNone(gradient)
322
323  @test_util.run_v1_only("b/120545219")
324  def testDependentYs(self):
325    with self.cached_session():
326      x = constant_op.constant(3.0)
327      y = math_ops.square(x)
328      y1 = math_ops.square(y)
329      y2 = math_ops.square(y1)
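      # With x = 3: y = 9, y1 = 81, y2 = 6561, so
      # d(y + y2)/dx = 2*x + 8*x**7 = 6 + 17496 = 17502.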
330      g = gradients.gradients([y, y2], x)
331      self.assertAllClose(17502.0, g[0].eval())
332      g = gradients.gradients(y + y2, x)
333      self.assertAllClose(17502.0, g[0].eval())
334      z = array_ops.identity(y)
335      z2 = array_ops.identity(y2)
336      g = gradients.gradients([z, z2], x)
337      self.assertAllClose(17502.0, g[0].eval())
338
339  @test_util.run_v1_only("b/120545219")
340  def testPartialDerivatives(self):
341    with self.cached_session():
342      x = constant_op.constant(1.)
343      y = 2 * x
344      z = x + y
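      # Total derivative: dz/dx = 1 + (dz/dy)*(dy/dx) = 1 + 2 = 3. With
      # stop_gradients=[x, y], only the direct dependencies count, giving 1.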
345      totalg = gradients.gradients(z, [x, y])
346      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
347      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
348      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
349
350  @test_util.run_v1_only("b/120545219")
351  def testStopGradients(self):
352    def _MakeGraph(rng, stop_gradients=()):
353      def _FunctionOf(xs, k=3):
354        return ops.convert_to_tensor(
355            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
356            + rng.rand(k, k))
357
358      a = _FunctionOf([])
359      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
360      b = _FunctionOf([a])
361      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
362      c = _FunctionOf([a, b])
363      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
364      d = _FunctionOf([b, c])
365      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
366      return dict(a=a, b=b, c=c, d=d)
367
368    def _Gradients(ys, xs, **kwargs):
369      dydxs = gradients.gradients(ys, xs, **kwargs)
370      dydxs = [0. * x if dydx is None else dydx
371               for x, dydx in zip(xs, dydxs)]
372      return dydxs
373
374    seed = np.random.randint(1000)
375    cases = []
376    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
377    graph = _MakeGraph(np.random.RandomState(seed))
378    for constants in subsets:
379      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
380      for variables_ in subsets:
381        # compute the gradient when stopped using tf.stop_gradient
382        grad1 = _Gradients([graph_with_stops["d"]],
383                           [graph_with_stops[v] for v in variables_])
384        # compute the gradient when stopped using the stop_gradients kwarg
385        grad2 = _Gradients([graph["d"]],
386                           [graph[v] for v in variables_],
387                           stop_gradients=[graph[v] for v in constants])
388        cases.append(dict(grad1=grad1, grad2=grad2,
389                          constants=constants, variables=variables_))
390
391    # evaluate all tensors in one call to session.run for speed
392    with self.cached_session() as sess:
393      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
394
395    for (npgrad1, npgrad2), case in zip(results, cases):
396      for a, b in zip(npgrad1, npgrad2):
397        np.testing.assert_allclose(a, b)
398
399  def testUnconnectedGradientsNoneUnconnectedGradients(self):
400    with ops.Graph().as_default():
401      x = constant(1.0, shape=[2, 2])
402      y = constant(3.0, shape=[3, 1])
403      grad = gradients.gradients(
404          [y], [x], unconnected_gradients="none")
405    self.assertIsNone(grad[0])
406
407  def testUnconnectedGradientsZerosUnconnectedGradients(self):
408    with ops.Graph().as_default():
409      x = constant(1.0, shape=[2, 2])
410      y = constant(3.0, shape=[3, 1])
411      grads = gradients.gradients(
412          [y], [x], unconnected_gradients="zero")
413      with self.cached_session() as sess:
414        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])
415
416  def testUnconnectedGradientsZeroConnectedGradients(self):
417    with ops.Graph().as_default():
418      x = constant(1.0)
419      y = x * 3.0
420      grad = gradients.gradients(
421          [y], [x], unconnected_gradients="zero")
422      with self.cached_session() as sess:
423        self.assertEqual(3.0, self.evaluate(grad)[0])
424
425  def testUnknownUnconnectedGradientsValueGiven(self):
426    with ops.Graph().as_default():
427      x = constant(1.0)
428      y = constant(1.0)
429      with self.assertRaisesRegexp(
430          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
431        gradients.gradients([y], [x], unconnected_gradients="nonsense")
432
433
434class FunctionGradientsTest(test_util.TensorFlowTestCase):
435
436  @classmethod
437  def XSquarePlusB(cls, x, b):
438    return x * x + b
439
440  @classmethod
441  def XSquarePlusBGradient(cls, x, b, g):
442    # Perturb gradients (multiply by 2), so we can test that this was called.
443    g *= 2.0
444    return g * 2.0 * x, g
445
446  @classmethod
447  def _PythonGradient(cls, op, grad):
448    # Perturb gradients (multiply by 3), so we can test that this was called.
449    grad *= 3.0
450    return grad * op.inputs[0] * 2.0, grad
451
452  @classmethod
453  def _GetFunc(cls, **kwargs):
454    return framework_function.Defun(dtypes.float32, dtypes.float32,
455                                    **kwargs)(cls.XSquarePlusB)
456
457  def _GetFuncGradients(self, f, x_value, b_value):
458    x = constant_op.constant(x_value, name="x")
459    b = constant_op.constant(b_value, name="b")
460
461    y = f(x, b)
462    grads = gradients.gradients(y, [x, b])
463    with self.cached_session() as sess:
464      return sess.run(grads)
465
466  def testFunctionGradientsBasic(self):
467    g = ops.Graph()
468    with g.as_default():
469      f = self._GetFunc()
470      # Get gradients (should add SymbolicGradient node for function).
471      grads = self._GetFuncGradients(f, [2.0], [1.0])
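      # d/dx (x^2 + b) = 2*x = 4 at x = 2, and d/db = 1.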
472      self.assertAllEqual([4.0], grads[0])
473      self.assertAllEqual([1.0], grads[1])
474
475  def testFunctionGradientsComposition(self):
476    with ops.Graph().as_default():
477      f = self._GetFunc()
478      x = constant_op.constant([2.0], name="x")
479      b1 = constant_op.constant([1.0], name="b1")
480      b2 = constant_op.constant([1.0], name="b2")
481
482      y = f(f(x, b1), b2)
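      # y = (x^2 + b1)^2 + b2, so at x = 2, b1 = 1:
      # dy/dx = 2*(x^2 + b1)*2*x = 40 and dy/db1 = 2*(x^2 + b1) = 10.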
483      # Build gradient graph (should add SymbolicGradient node for function).
484      grads = gradients.gradients(y, [x, b1])
485
486      with self.cached_session() as sess:
487        self.assertAllEqual([40.0], self.evaluate(grads)[0])
488        self.assertAllEqual([10.0], self.evaluate(grads)[1])
489
490  def testFunctionGradientsWithGradFunc(self):
491    g = ops.Graph()
492    with g.as_default():
493      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
494                                           dtypes.float32)(
495                                               self.XSquarePlusBGradient)
496      f = self._GetFunc(grad_func=grad_func)
497      # Get gradients (should add SymbolicGradient node for function, which
498      # uses the grad_func above, which multiplies all gradients by 2).
499      grads = self._GetFuncGradients(f, [2.0], [1.0])
500      self.assertAllEqual([4.0 * 2], grads[0])
501      self.assertAllEqual([1.0 * 2], grads[1])
502
503  def testFunctionGradientWithRegistration(self):
504    g = ops.Graph()
505    with g.as_default():
506      f = self._GetFunc(python_grad_func=self._PythonGradient)
507      # Get gradients, using the python gradient function. It multiplies the
508      # gradients by 3.
509      grads = self._GetFuncGradients(f, [2.0], [1.0])
510      self.assertAllEqual([4.0 * 3], grads[0])
511      self.assertAllEqual([1.0 * 3], grads[1])
512
513  def testFunctionGradientWithGradFuncAndRegistration(self):
514    g = ops.Graph()
515    with g.as_default():
516      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
517                                           dtypes.float32)(
518                                               self.XSquarePlusBGradient)
519      with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
520        f = self._GetFunc(
521            grad_func=grad_func, python_grad_func=self._PythonGradient)
522        f.add_to_graph(ops.Graph())
523
524  def testGradientWrtCaptured(self):
525    with ops.Graph().as_default():
526      x = constant_op.constant(1.0, name="x")
527
528      @function.defun()
529      def Foo():
530        y = math_ops.multiply(x, 2.0, name="y")
531        g = gradients_impl.gradients(y, x)
532        return g[0]
533
534      f = Foo()
535      with self.cached_session() as sess:
536        self.assertEqual(self.evaluate(f), 2.0)
537
538  def testGradientOfCaptured(self):
539    with ops.Graph().as_default():
540      x = constant_op.constant(1.0, name="x")
541      y = math_ops.multiply(x, 2.0, name="y")
542
543      @framework_function.Defun()
544      def Foo():
545        g = gradients_impl.gradients(y, x)
546        return g[0]
547
548      f = Foo()
549      with self.cached_session() as sess:
550        self.assertEqual(self.evaluate(f), 2.0)
551
552  def testCapturedResourceVariable(self):
553    with ops.Graph().as_default():
554      var = resource_variable_ops.ResourceVariable(1.0, name="var")
555
556      @function.defun()
557      def Foo():
558        y = math_ops.multiply(var, 2.0, name="y")
559        g = gradients_impl.gradients(y, var)
560        return g[0]
561
562      f = Foo()
563      with self.cached_session() as sess:
564        self.evaluate(variables.global_variables_initializer())
565        self.assertEqual(self.evaluate(f), 2.0)
566
567  def testCapturedNested(self):
568    with ops.Graph().as_default():
569      x1 = constant_op.constant(1.0, name="x1")
570      x2 = constant_op.constant(2.0, name="x2")
571      x3 = math_ops.multiply(x1, x2, name="x3")
572
573      @function.defun()
574      def Outer():
575        outer1 = array_ops.identity(x1, name="outer1")
576
577        @function.defun()
578        def Inner():
579          inner1 = array_ops.identity(outer1, name="inner1")
580          inner2 = array_ops.identity(x2, name="inner2")
581          inner3 = array_ops.identity(x3, name="inner3")
582          return gradients_impl.gradients([inner1, inner2, inner3, x1],
583                                          [x1, x2])
584
585        return Inner()
586
587      x1_grad, x2_grad = Outer()
588      with self.cached_session() as sess:
589        # 1.0 + None + 2.0 + 1.0 = 4.0
590        self.assertEqual(self.evaluate(x1_grad), 4.0)
591        # None + 1.0 + 1.0 + None = 2.0
592        self.assertEqual(self.evaluate(x2_grad), 2.0)
593
594  def testCapturedFromFunction(self):
595    with ops.Graph().as_default():
596      x = constant_op.constant(1.0, name="x")
597
598      @function.defun()
599      def Outer():
600        y = math_ops.multiply(x, 2.0, name="y")
601
602        @function.defun()
603        def Inner():
604          z = math_ops.multiply(y, 3.0, name="z")
605          g = gradients_impl.gradients(z, y)
606          return g[0]
607
608        return Inner()
609
610      z_grad = Outer()
611      with self.cached_session() as sess:
612        self.assertEqual(self.evaluate(z_grad), 3.0)
613
614  def testCapturedEagerTensors(self):
615    # Test that we can handle captured eager tensors unrelated to the gradient
616    # computation (i.e. we need to ignore them).
617    # TODO(skyewm): make it an error if you try to take the gradient wrt a
618    # captured EagerTensor
619    with context.eager_mode():
620      c = constant_op.constant(2.0, name="c")
621
622      @function.defun
623      def Foo():
624        x = constant_op.constant(10.0, name="x")
625        y = math_ops.multiply(x, c, name="y")
626        # Regression test for b/122564611.
627        z = math_ops.multiply(c, y, name="z")
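        # z = c * (x * c) = c^2 * x, so dz/dx = c^2 = 4.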
628        g = gradients_impl.gradients(z, x)
629        return g[0]
630
631      self.assertEqual(Foo().numpy(), 4.0)
632
633
634class StopGradientTest(test_util.TensorFlowTestCase):
635
636  def testStopGradient(self):
637    with ops.Graph().as_default():
638      inp = constant(1.0, shape=[100, 32], name="in")
639      out = array_ops.stop_gradient(inp)
640      igrad = gradients.gradients(out, inp)[0]
641    self.assertIsNone(igrad)
642
643
644class PreventGradientTest(test_util.TensorFlowTestCase):
645
646  def testPreventGradient(self):
647    with ops.Graph().as_default():
648      inp = constant(1.0, shape=[100, 32], name="in")
649      out = array_ops.prevent_gradient(inp)
650      with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
651        _ = gradients.gradients(out, inp)
652
653
654class HessianVectorProductTest(test_util.TensorFlowTestCase):
655
656  @test_util.run_v1_only("b/120545219")
657  def testHessianVectorProduct(self):
658    # Manually compute the Hessian explicitly for a low-dimensional problem
659    # and check that HessianVectorProduct matches multiplication by the
660    # explicit Hessian.
661    # Specifically, the Hessian of f(x) = x^T A x is
662    # H = A + A^T.
663    # We expect HessianVectorProduct(f(x), x, v) to be H v.
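    # (grad f(x) = (A + A^T) x, so differentiating once more gives H = A + A^T.)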
664    m = 4
665    rng = np.random.RandomState([1, 2, 3])
666    mat_value = rng.randn(m, m).astype("float32")
667    v_value = rng.randn(m, 1).astype("float32")
668    x_value = rng.randn(m, 1).astype("float32")
669    hess_value = mat_value + mat_value.T
670    hess_v_value = np.dot(hess_value, v_value)
671    for use_gpu in [False, True]:
672      with self.cached_session(use_gpu=use_gpu):
673        mat = constant_op.constant(mat_value)
674        v = constant_op.constant(v_value)
675        x = constant_op.constant(x_value)
676        mat_x = math_ops.matmul(mat, x, name="Ax")
677        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
678        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
679        hess_v_actual = self.evaluate(hess_v)
680      self.assertAllClose(hess_v_value, hess_v_actual)
681
682
683class HessianTest(test_util.TensorFlowTestCase):
684
685  @test_util.run_v1_only("b/120545219")
686  def testHessian1D(self):
687    # Manually compute the Hessian explicitly for a low-dimensional problem
688    # and check that `hessian` matches. Specifically, the Hessian of
689    # f(x) = x^T A x is H = A + A^T.
690    m = 4
691    rng = np.random.RandomState([1, 2, 3])
692    mat_value = rng.randn(m, m).astype("float32")
693    x_value = rng.randn(m).astype("float32")
694    hess_value = mat_value + mat_value.T
695    with self.session(use_gpu=True):
696      mat = constant_op.constant(mat_value)
697      x = constant_op.constant(x_value)
698      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
699      hess = gradients.hessians(x_mat_x, x)[0]
700      hess_actual = self.evaluate(hess)
701    self.assertAllClose(hess_value, hess_actual)
702
703  @test_util.run_v1_only("b/120545219")
704  def testHessian1D_multi(self):
705    # Test the computation of the hessian with respect to multiple tensors
706    m = 4
707    n = 3
708    rng = np.random.RandomState([1, 2, 3])
709    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
710    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
711    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
712    with self.session(use_gpu=True):
713      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
714      xs = [constant_op.constant(x_value) for x_value in x_values]
715      xs_mats_xs = [
716          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
717          for x, mat in zip(xs, mats)
718      ]
719      hessians = gradients.hessians(xs_mats_xs, xs)
720      hessians_actual = [hess.eval() for hess in hessians]
721    for hess_value, hess_actual in zip(hess_values, hessians_actual):
722      self.assertAllClose(hess_value, hess_actual)
723
724  @test_util.run_v1_only("b/120545219")
725  def testHessianInvalidDimension(self):
726    for shape in [(10, 10), None]:
727      with self.cached_session(use_gpu=True):
728        x = array_ops.placeholder(dtypes.float32, shape)
729        # Expect a ValueError because the dimensions are wrong
730        with self.assertRaises(ValueError):
731          gradients.hessians(x, x)
732
733  @test_util.run_v1_only("b/120545219")
734  def testHessian2D_square_matrix(self):
735    # Manually compute the Hessian explicitly for a low-dimensional problem
736    # and check that `hessian` matches. Specifically, the Hessian of
737    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
738    m = 3
739    rng = np.random.RandomState([1, 2, 3])
740    x_value = rng.randn(m, m).astype("float32")
741    with self.session(use_gpu=True):
742      x = constant_op.constant(x_value)
743      x_square = math_ops.reduce_sum(
744          math_ops.matmul(array_ops.transpose(x), x) * 0.5
745      )
746      hess = gradients.hessians(x_square, x)[0]
747      hess_actual = self.evaluate(hess)
748    hess_value = np.bmat([
749        [elem*np.ones((m, m)) for elem in vec]
750        for vec in np.eye(m)
751    ]).astype("float32")
752    self.assertAllEqual((m, m, m, m), hess_actual.shape)
753    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))
754
755  @test_util.run_v1_only("b/120545219")
756  def testHessian2D_non_square_matrix(self):
757    m = 3
758    n = 4
759    rng = np.random.RandomState([1, 2, 3])
760    x_value = rng.randn(m, n).astype("float32")
761    with self.session(use_gpu=True):
762      x = constant_op.constant(x_value)
763      x_square = math_ops.reduce_sum(
764          math_ops.matmul(array_ops.transpose(x), x) * 0.5
765      )
766      hess = gradients.hessians(x_square, x)[0]
767      hess_actual = self.evaluate(hess)
768    hess_value = np.bmat([
769        [elem*np.ones((n, n)) for elem in vec]
770        for vec in np.eye(m)
771    ]).astype("float32")
772    self.assertAllEqual((m, n, m, n), hess_actual.shape)
773    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))
774
775
776class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
777
778  @test_util.run_v1_only("b/120545219")
779  def testIndexedSlicesToTensor(self):
780    with self.cached_session():
781      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
782      c = constant_op.constant(np_val)
783      c_sparse = math_ops._as_indexed_slices(c)
784      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
785      c_dense = math_ops.multiply(c_sparse, 1.0)
786      self.assertAllClose(np_val, self.evaluate(c_dense))
787
788  @test_util.run_v1_only("b/120545219")
789  def testIndexedSlicesToTensorList(self):
790    with self.cached_session():
791      numpy_list = []
792      dense_list = []
793      sparse_list = []
794      for _ in range(3):
795        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
796        c = constant_op.constant(np_val)
797        c_sparse = math_ops._as_indexed_slices(c)
798        numpy_list.append(np_val)
799        dense_list.append(c)
800        sparse_list.append(c_sparse)
801      packed_dense = array_ops.stack(dense_list)
802      packed_sparse = array_ops.stack(sparse_list)
803      self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse))
804
805  @test_util.run_v1_only("b/120545219")
806  def testInt64Indices(self):
807    with self.cached_session():
808      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
809      c = constant_op.constant(np_val)
810      c_sparse = math_ops._as_indexed_slices(c)
811      c_sparse = ops.IndexedSlices(
812          c_sparse.values,
813          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
814      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
815      c_dense = math_ops.multiply(c_sparse, 1.0)
816      self.assertAllClose(np_val, self.evaluate(c_dense))
817
818  @test_util.run_v1_only("b/120545219")
819  def testWarnings(self):
820    # TODO(gunan) Reenable after this issue is fixed:
821    # https://github.com/google/protobuf/issues/2812
822    if sys.version_info >= (3, 5):
823      self.skipTest("Skipped test for Python 3.5+")
824
825    # Smaller than the threshold: no warning.
826    c_sparse = ops.IndexedSlices(
827        array_ops.placeholder(dtypes.float32),
828        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
829    with warnings.catch_warnings(record=True) as w:
830      math_ops.multiply(c_sparse, 1.0)
831    self.assertEqual(0, len(w))
832
833    # Greater than or equal to the threshold: warning.
834    c_sparse = ops.IndexedSlices(
835        array_ops.placeholder(dtypes.float32),
836        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
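    # The dense shape has 100**4 = 100,000,000 elements, which meets the
    # warning threshold.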
837    # "always" filter prevents the warning from being suppressed if it was
838    # already triggered in a different test.
839    warnings.simplefilter("always")
840    with warnings.catch_warnings(record=True) as w:
841      math_ops.multiply(c_sparse, 1.0)
842    self.assertEqual(1, len(w))
843    self.assertTrue(
844        "with 100000000 elements. This may consume a large amount of memory." in
845        str(w[0].message))
846
847    # Unknown dense shape: warning.
848    c_sparse = ops.IndexedSlices(
849        array_ops.placeholder(dtypes.float32),
850        array_ops.placeholder(dtypes.int32),
851        array_ops.placeholder(dtypes.int32))
852    with warnings.catch_warnings(record=True) as w:
853      math_ops.multiply(c_sparse, 1.0)
854    self.assertEqual(1, len(w))
855    self.assertTrue(
856        "of unknown shape. This may consume a large amount of memory." in
857        str(w[0].message))
858
859
860class OnlyRealGradientsTest(test_util.TensorFlowTestCase):
861
862  @test_util.run_v1_only("b/120545219")
863  def testRealOnly(self):
864    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
865    y = math_ops.square(x)
866    with self.assertRaisesRegexp(
867        TypeError,
868        r"Gradients of complex tensors must set grad_ys "
869        r"\(y\.dtype = tf\.complex64\)"):
870      gradients.gradients(y, x)
871
872
873class ResourceCondTest(test_util.TensorFlowTestCase):
874
875  @test_util.run_v1_only("b/120545219")
876  def testBasic(self):
877    gamma = resource_variable_ops.ResourceVariable(
878        np.random.random((3,)),
879        dtype="float32", name="gamma")
880
881    inputs = array_ops.ones(shape=(3,), dtype="float32")
882
883    def TestFn():
884      output = inputs + gamma
885      return output
886
887    training = array_ops.placeholder_with_default(True, shape=())
888    output = control_flow_ops.cond(
889        training, TestFn, lambda: inputs)
890
891    loss = output
892
893    grads = gradients.gradients(
894        loss, [gamma])
895    self.assertNotIn(None, grads)
896
897
898class GetDependentVariablesTest(test_util.TensorFlowTestCase):
899
900  def testNoVariables(self):
901    with ops.Graph().as_default():
902      func = lambda x: array_ops.identity(x) + 5.0
903      input_t = constant_op.constant(2.0)
904      result_t = func(input_t)
905      dependent_vars = custom_gradient.get_dependent_variables(
906          [input_t], [result_t])
907
908      # There are no variables.
909      self.assertEqual(dependent_vars, [])
910
911  def testVariablesOutside(self):
912    with ops.Graph().as_default():
913      init = constant_op.constant(100.0)
914      var = variables.Variable(init)
915
916      # The variable is closed over. It should be found.
917      func = lambda x: array_ops.identity(x) + 5.0 + var
918
919      input_t = constant_op.constant(2.0)
920      result_t = func(input_t)
921      dependent_vars = custom_gradient.get_dependent_variables(
922          [input_t], [result_t])
923      self.assertEqual(dependent_vars, [var])
924
925  def testVariableSamePrefix(self):
926    with ops.Graph().as_default():
927      var_name = "my_variable"
928      v_z = variable_scope.get_variable(var_name, shape=())
929      v_o = variable_scope.get_variable(var_name + "_ones", shape=())
930
931      # The variable is closed over. It should be found.
932      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o
933
934      input_t = constant_op.constant(2.0)
935      result_t = func(input_t)
936      dependent_vars = custom_gradient.get_dependent_variables(
937          [input_t], [result_t])
938      self.assertEqual(set(dependent_vars), set([v_o, v_z]))
939
940  def testVariablesOutsideButDSeparated(self):
941    with ops.Graph().as_default():
942      init = constant_op.constant(100.0)
943      var = variables.Variable(init)
944
945      # The variable is d-separated by the inputs. It should not be found.
946      input_t = array_ops.identity(var) * 5.0
947
948      func = lambda x: array_ops.identity(x) + 5.0
949      result_t = func(input_t)
950      dependent_vars = custom_gradient.get_dependent_variables(
951          [input_t], [result_t])
952      self.assertEqual(dependent_vars, [])
953
954  def testVariablesOutsideAndNonDifferentiable(self):
955    with ops.Graph().as_default():
956      init = constant_op.constant(100.0, shape=(5,))
957      var = variables.Variable(init, shape=(5,))
958
959      def _Func(x):
960        # non-differentiable dependency on var.
961        # the variable should not be found.
962        y = array_ops.ones_like(var)
963        return array_ops.identity(x) + 5.0 + y
964
965      input_t = constant_op.constant(2.0)
966      result_t = _Func(input_t)
967      dependent_vars = custom_gradient.get_dependent_variables(
968          [input_t], [result_t])
969      self.assertEqual(dependent_vars, [])
970
971  def testVariablesOutsideAndNonTrainable(self):
972    with ops.Graph().as_default():
973      init = constant_op.constant(100.0, shape=(5,))
974
975      # Both variables are used in the function but only the trainable one
976      # should be found.
977      var_trainable = variables.Variable(init, shape=(5,))
978      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)
979
980      def _Func(x):
981        del x
982        return var_trainable + var_nontrainable
983
984      input_t = constant_op.constant(2.0)
985      result_t = _Func(input_t)
986      dependent_vars = custom_gradient.get_dependent_variables(
987          [input_t], [result_t])
988      self.assertEqual(dependent_vars, [var_trainable])
989
990  def testNesting(self):
991    with ops.Graph().as_default():
992      init = constant_op.constant(100.0, shape=(5,))
993      var = variables.Variable(init, shape=(5,))
994
995      def _Func(inputs):
996        x = inputs["x"]
997        result = array_ops.identity(x) + 5.0 + var
998        return {
999            "y": result
1000        }
1001
1002      input_t = constant_op.constant(2.0)
1003      func_inputs = {
1004          "x": input_t
1005      }
1006      result_t = _Func(func_inputs)
1007
1008      # Ensure we can deal with dictionary input and output.
1009      dependent_vars = custom_gradient.get_dependent_variables(
1010          func_inputs, result_t)
1011      self.assertEqual(dependent_vars, [var])
1012
1013  def testVariablesOutsideAndCustomGradient(self):
1014    with ops.Graph().as_default():
1015      init = constant_op.constant(100.0, shape=(5,))
1016      var = variables.Variable(init, shape=(5,))
1017
1018      @custom_gradient.custom_gradient
1019      def _MyOnesLike(x):
1020        """Dummy version of ones_like which defines a gradient."""
1021
1022        output = array_ops.ones_like(x)
1023
1024        def _Grad(dy):
1025          return array_ops.identity(dy)
1026
1027        return output, _Grad
1028
1029      def _Func(x):
1030        # non-differentiable operation with custom gradient.
1031        # The variable should be found.
1032        y = _MyOnesLike(var)
1033        return array_ops.identity(x) + 5.0 + y
1034
1035      input_t = constant_op.constant(2.0)
1036      result_t = _Func(input_t)
1037      dependent_vars = custom_gradient.get_dependent_variables(
1038          [input_t], [result_t])
1039      self.assertEqual(dependent_vars, [var])
1040
1041
1042class CustomGradientTest(test_util.TensorFlowTestCase):
1043
1044  def testCustomGradientTrivial(self):
1045
1046    @custom_gradient.custom_gradient
1047    def MyIdentity(x):
1048
1049      def Grad(dy):
1050        return [3 * dy]
1051
1052      return x, Grad
1053
1054    with ops.Graph().as_default():
1055      x = constant(3.)
1056      y = MyIdentity(MyIdentity(x))
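      # Each MyIdentity scales incoming gradients by 3, so dy/dx = 3 * 3 = 9.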
1057      dy = gradients.gradients(y, x)[0]
1058      with session.Session():
1059        self.assertEqual(9., self.evaluate(dy))
1060
1061  def testCustomGradient(self):
1062
1063    @custom_gradient.custom_gradient
1064    def MyMultiply(x1, x2):
1065      result = x1 * x2
1066
1067      def Grad(dy):
1068        # Switched the ordering here.
1069        return [dy * x1, dy * x2]
1070
1071      return result, Grad
1072
1073    with ops.Graph().as_default():
1074      x1 = constant(3.)
1075      x2 = constant(5.)
1076      y = MyMultiply(x1, x2)
1077      dy = gradients.gradients(y, [x1, x2])
1078      with session.Session() as sess:
1079        self.assertAllEqual([3., 5.], self.evaluate(dy))
1080
1081  def testCustomGradientClass(self):
1082
1083    class Model(object):
1084
1085      @custom_gradient.custom_gradient
1086      def Multiply(self, x1, x2):
1087        result = x1 * x2
1088        grad = lambda dy: (dy * x1, dy * x2)
1089        return result, grad
1090
1091    with ops.Graph().as_default():
1092      x1 = constant(3.)
1093      x2 = constant(5.)
1094      m = Model()
1095      y = m.Multiply(x1, x2)
1096      dy = gradients.gradients(y, [x1, x2])
1097      self.assertAllEqual([3., 5.], self.evaluate(dy))
1098
1099  def testCustomGradientErrors(self):
1100
1101    @custom_gradient.custom_gradient
1102    def F(x):
1103
1104      def Grad(_):
1105        raise RuntimeError("x")
1106
1107      return x, Grad
1108
1109    with ops.Graph().as_default():
1110      x = constant(1.0)
1111      y = F(x)
1112      with self.assertRaises(RuntimeError):
1113        gradients.gradients(y, x)
1114
1115  def testCustomGradientWithVariables(self):
1116
1117    @custom_gradient.custom_gradient
1118    def F(x):
1119      out = core_layers.dense(x, 3, use_bias=False)
1120
1121      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1122        self.assertEqual(1, len(variables))
1123        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1124        return grads[0], [array_ops.ones((4, 3))]
1125
1126      return out, Grad
1127
1128    with ops.Graph().as_default():
1129      x = array_ops.ones((2, 4))
1130      with variable_scope.variable_scope("f", use_resource=True) as vs:
1131        y = F(x)
1132        all_vars = vs.global_variables()
1133        assert len(all_vars) == 1
1134      grads = gradients.gradients(y, [x, all_vars[0]])
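      # Grad returns ones((4, 3)) for the dense kernel variable, so the
      # reduce_sum of that gradient below is 12.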
1135      for g in grads:
1136        self.assertIsNotNone(g)
1137      with session.Session() as sess:
1138        self.evaluate(variables.global_variables_initializer())
1139        dw = sess.run(math_ops.reduce_sum(grads[1]))
1140        self.assertEqual(12., dw)
1141
1142  def testCustomGradientWithVariablesNoFalsePositives(self):
1143
1144    @custom_gradient.custom_gradient
1145    def F(x):
1146      out = core_layers.dense(x, 3, use_bias=False)
1147
1148      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1149        self.assertEqual(1, len(variables))
1150        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1151        return grads[0], [array_ops.ones((3, 3))]
1152
1153      return out, Grad
1154
1155    with ops.Graph().as_default():
1156      with variable_scope.variable_scope("f", use_resource=True) as vs:
1157        a = array_ops.ones((2, 4))
1158
1159        # Variables in these layers shouldn't be picked up by the decorator.
1160        b = core_layers.dense(a, 3, use_bias=False)
1161        c = core_layers.dense(b, 3, use_bias=False)
1162        x = core_layers.dense(b, 3, use_bias=False) + c
1163
1164        # Only the variables used in F.
1165        y = F(x)
1166
1167        all_vars = vs.global_variables()
1168        assert len(all_vars) == 4
1169      grads = gradients.gradients(y, [x] + all_vars)
1170      _, var_grads = grads[0], grads[1:]
1171      for g in grads:
1172        self.assertIsNotNone(g)
1173      with session.Session() as sess:
1174        self.evaluate(variables.global_variables_initializer())
1175        dw = sess.run(math_ops.reduce_sum(var_grads[-1]))
1176        self.assertEqual(9., dw)
1177
1178  def testCustomGradientWithVariablesEager(self):
1179    with context.eager_mode():
1180      layer = core_layers.Dense(4, use_bias=False)
1181
1182      @custom_gradient.custom_gradient
1183      def F(x):
1184        out = layer(x)
1185
1186        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1187          del out_grad
1188          self.assertEqual(1, len(variables))
1189          return (array_ops.ones((3, 2)),
1190                  [array_ops.ones((2, 4))])
1191
1192        return out, Grad
1193
1194      x = array_ops.ones((3, 2)) + 2.
1195      with backprop.GradientTape() as tape:
1196        tape.watch(x)
1197        y = F(x)
1198      w, = layer.variables
1199      dx, dw = tape.gradient(y, [x, w])
1200      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
1201      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())
1202
1203  @test_util.run_v1_only("b/120545219")
1204  def testCustomGradientErrorsWithNonResourceVariables(self):
1205
1206    def F(x, use_resource=False):
1207      with variable_scope.variable_scope("f", use_resource=use_resource):
1208        out = core_layers.dense(x, 4, use_bias=False)
1209
1210      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1211        del out_grad
1212        self.assertEqual(1, len(variables))
1213        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
1214
1215      return out, Grad
1216
1217    @custom_gradient.custom_gradient
1218    def FResource(x):
1219      return F(x, use_resource=True)
1220
1221    @custom_gradient.custom_gradient
1222    def FNonResource(x):
1223      return F(x, use_resource=False)
1224
1225    x = array_ops.ones((3, 2)) + 2.
1226
1227    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
1228    with variable_scope.variable_scope("vs1", use_resource=True):
1229      with self.assertRaisesWithPredicateMatch(TypeError,
1230                                               "must be `ResourceVariable`s"):
1231        FNonResource(x)
1232
1233    # Wrapping scope has use_resource=False but inner scope sets to True.
1234    # Passes.
1235    with variable_scope.variable_scope("vs2", use_resource=False):
1236      FResource(x)
1237
1238  def testWithNumpyInputs(self):
1239    with context.eager_mode():
1240
1241      @custom_gradient.custom_gradient
1242      def F(x):
1243        out = x
1244
1245        def Grad(_):
1246          return (None, None)
1247
1248        return out, Grad
1249
1250      x = np.ones((3, 2), dtype=np.float32)
1251      # Smoke test to ensure numpy inputs are accepted
1252      F(x)
1253
1254  @test_util.run_v1_only("b/120545219")
1255  def testRVGradientsDynamicCond(self):
1256    with self.cached_session():
1257      alpha = resource_variable_ops.ResourceVariable(
1258          np.random.random((1,)),
1259          dtype="float32")
1260
1261      conditional = array_ops.placeholder_with_default(True, shape=())
1262      output = control_flow_ops.cond(
1263          conditional, lambda: alpha * 2, lambda: alpha * 3)
1264
1265      g, = gradients_impl.gradients(output, alpha)
1266      self.evaluate(variables.global_variables_initializer())
1267      self.assertAllEqual(g.eval(), [2.0])
1268      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
1269
1270  def testRecursiveCustomGradient(self):
1271    @custom_gradient.custom_gradient
1272    def F(x):
1273      out = core_layers.dense(x, 3, use_bias=False)
1274
1275      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1276        self.assertEqual(1, len(variables))
1277        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1278        return grads[0], [array_ops.ones((4, 3))]
1279
1280      return out, Grad
1281
1282    @custom_gradient.custom_gradient
1283    def DoubleF(x):
1284      out = F(x)
1285
1286      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1287        self.assertEqual(1, len(variables))
1288        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1289        return grads[0], [array_ops.ones((4, 3))]
1290
1291      return out, Grad
1292    with ops.Graph().as_default():
1293      x = array_ops.ones((2, 4))
1294      with variable_scope.variable_scope("f", use_resource=True) as vs:
1295        y = DoubleF(x)
1296        all_vars = vs.global_variables()
1297        assert len(all_vars) == 1
1298      grads = gradients.gradients(y, [x, all_vars[0]])
1299      for g in grads:
1300        self.assertIsNotNone(g)
1301      with session.Session() as sess:
1302        self.evaluate(variables.global_variables_initializer())
1303        dw = sess.run(math_ops.reduce_sum(grads[1]))
1304        self.assertEqual(12., dw)
1305
1306
1307class TensorListGradientsTest(test_util.TensorFlowTestCase):
1308
1309  def testDefaultGradYs(self):
1310    with ops.Graph().as_default():
1311      tl = list_ops.empty_tensor_list(
1312          element_dtype=dtypes.float32,
1313          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1314      a = constant(1.0)
1315      tl = list_ops.tensor_list_push_back(tl, a)
1316
1317      grad_tl = list_ops.empty_tensor_list(
1318          element_dtype=dtypes.float32,
1319          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1320      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
1321
1322      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
1323      with self.cached_session() as sess:
1324        self.assertEqual(self.evaluate(grad), 5.)
1325
1326
1327class TestKerasModelClass(training.Model):
1328  """A simple tensorflow keras Model class definition."""
1329
1330  def __init__(self, width):
1331    super(TestKerasModelClass, self).__init__()
1332
1333    self.weight = variable_scope.get_variable(
1334        name="test_keras_var",
1335        shape=width,
1336        dtype=dtypes.float32,
1337        trainable=True,
1338        use_resource=True,
1339    )
1340
1341  def call(self, inputs):
1342    return self.weight * inputs
1343
1344
1345class VariablesGradientTest(test_util.TensorFlowTestCase):
1346
1347  def _TestVariablesGradient(self, inputs, test_model, vars_to_grad):
1348    """Returns gradients of `test_model` with respect to `vars_to_grad`."""
1349
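    # recompute_grad re-runs the forward pass during backprop instead of
    # keeping intermediate activations; the gradients it produces should match
    # those of the unwrapped model.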
1350    test_model_re = custom_gradient.recompute_grad(test_model)
1351
1352    with backprop.GradientTape(persistent=True) as tape:
1353      tape.watch(vars_to_grad)
1354      out_re = test_model_re(inputs)
1355      out = test_model(inputs)
1356
1357    grads_re = tape.gradient(out_re, vars_to_grad)
1358    grads = tape.gradient(out, vars_to_grad)
1359
1360    return grads_re, grads
1361
1362  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
1363    """Returns gradients of `test_model` with respect to `vars_to_grad`."""
1364
1365    test_fn_re = custom_gradient.recompute_grad(test_fn)
1366
1367    with backprop.GradientTape(persistent=True) as tape:
1368      tape.watch(vars_to_grad)
1369      out_re = test_fn_re(inputs, vars_to_grad)
1370      out = test_fn(inputs, vars_to_grad)
1371
1372    grads_re = tape.gradient(out_re, vars_to_grad)
1373    grads = tape.gradient(out, vars_to_grad)
1374
1375    return grads_re, grads
1376
1377  @test_util.run_in_graph_and_eager_modes
1378  def testKerasRecompute(self):
1379    """Checks that recompute_grad works for a simple Keras Model."""
1380
1381    test_model = TestKerasModelClass(10)
1382    test_input = constant(np.zeros((10, 10), dtype=np.float32))
1383    self.evaluate(variables.global_variables_initializer())
1384    test_model(test_input)  # Ensures keras model is initialized.
1385    grads_re, grads = self._TestVariablesGradient(test_input, test_model,
1386                                                  test_input)
1387
1388    grads_re = self.evaluate(grads_re)
1389    grads = self.evaluate(grads)
1390    for g, g_re in zip(grads, grads_re):
1391      self.assertAllClose(g, g_re)
1392
1393    grads_re, grads = self._TestVariablesGradient(test_input, test_model,
1394                                                  test_model.variables)
1395
1396    grads_re = self.evaluate(grads_re)
1397    grads = self.evaluate(grads)
1398    for g, g_re in zip(grads, grads_re):
1399      self.assertAllClose(g, g_re)
1400
1401  @test_util.run_in_graph_and_eager_modes
1402  def testFnRecompute(self):
1403    """Checks that recompute_grad works grads of function args."""
1404
1405    def TestFn(inputs, input_vars):
1406      return inputs * input_vars
1407
1408    def TestFnSeq(inputs, input_vars):
1409      return (inputs * input_vars, inputs * input_vars * 2.0)
1410
1411    with variable_scope.variable_scope("test", use_resource=True):
1412      test_var = variable_scope.get_variable(
1413          name="test_var",
1414          shape=10,
1415          trainable=True,
1416      )
1417
1418      test_input = constant(np.zeros((10, 10), dtype=np.float32))
1419
1420      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
1421                                                      test_input)
1422
1423      grads_re = self.evaluate(grads_re)
1424      grads = self.evaluate(grads)
1425      for g, g_re in zip(grads, grads_re):
1426        self.assertAllClose(g, g_re)
1427
1428      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
1429                                                      test_var)
1430      grads_re = self.evaluate(grads_re)
1431      grads = self.evaluate(grads)
1432      for g, g_re in zip(grads, grads_re):
1433        self.assertAllClose(g, g_re)
1434
1435      # Regression test for wrapping sequence outputting functions.
1436      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
1437                                                      test_input)
1438      grads_re = self.evaluate(grads_re)
1439      grads = self.evaluate(grads)
1440      for g, g_re in zip(grads, grads_re):
1441        self.assertAllClose(g, g_re)
1442
1443      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
1444                                                      test_var)
1445      grads_re = self.evaluate(grads_re)
1446      grads = self.evaluate(grads)
1447      for g, g_re in zip(grads, grads_re):
1448        self.assertAllClose(g, g_re)
1449
1450  @test_util.deprecated_graph_mode_only
1451  def testFnRecomputeWithScopeGradientTape(self):
1452    """Checks that recompute_grad works with var scope and GradientTape."""
1453
1454    def TestFn(input_t):
1455      with variable_scope.variable_scope("inner_scope"):
1456        test_var = variable_scope.get_variable(
1457            name="test_var",
1458            shape=10,
1459            trainable=True,
1460        )
1461        return input_t * test_var
1462
1463    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))
1464
1465    with variable_scope.variable_scope(
1466        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
1467      test_fn_re = custom_gradient.recompute_grad(TestFn)
1468
1469      with backprop.GradientTape(persistent=True) as tape:
1470        out_re = test_fn_re(test_input_t)
1471        out = TestFn(test_input_t)
1472
1473    grads_re = tape.gradient(out_re, variables.trainable_variables())
1474    grads = tape.gradient(out, variables.trainable_variables())
1475
1476    grads_re = self.evaluate(grads_re)
1477    grads = self.evaluate(grads)
1478    for g, g_re in zip(grads, grads_re):
1479      self.assertAllClose(g, g_re)
1481
1482  @test_util.deprecated_graph_mode_only
1483  def testFnRecomputeWithScopeGradients(self):
1484    """Checks that recompute_grad works with var scope and gradients(..)."""
1485
1486    def TestFn(input_t):
1487      with variable_scope.variable_scope("inner_scope"):
1488        test_var = variable_scope.get_variable(
1489            name="test_var",
1490            shape=10,
1491            trainable=True,
1492        )
1493        return input_t * test_var
1494
1495    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))
1496
1497    with variable_scope.variable_scope(
1498        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
1499      test_fn_re = custom_gradient.recompute_grad(TestFn)
1500      out_re = test_fn_re(test_input_t)
1501      out = TestFn(test_input_t)
1502
1503    grads_re = gradients.gradients(out_re, variables.trainable_variables())
1504    grads = gradients.gradients(out, variables.trainable_variables())
1505
1506    grads_re = self.evaluate(grads_re)
1507    grads = self.evaluate(grads)
1508    for g, g_re in zip(grads, grads_re):
1509      self.assertAllClose(g, g_re)
1511
1512  @test_util.run_in_graph_and_eager_modes
1513  def testFnRecomputeSameTensor(self):
1514    """Check recompute_grad when wrapped f called as f(x, x) - b/147369366."""
1515
1516    def TestFnMul(x, y):
1517      return x * y
1518
1519    def TestFnSingleVar(x, y):
1520      # pylint: disable=unused-argument
1521      return x
1522
1523    with variable_scope.variable_scope("test", use_resource=True):
1524      x = array_ops.ones((10))
1525
1526      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul,
1527                                                      x)
1528      grads_re = self.evaluate(grads_re)
1529      grads = self.evaluate(grads)
1530      for g, g_re in zip(grads, grads_re):
1531        self.assertAllClose(g, g_re)
1532
1533      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar,
1534                                                      x)
1535      grads_re = self.evaluate(grads_re)
1536      grads = self.evaluate(grads)
1537      for g, g_re in zip(grads, grads_re):
1538        self.assertAllClose(g, g_re)
1539
1540
1541class GradPassThroughTest(test_util.TensorFlowTestCase):
1542
1543  @test_util.run_v1_only("b/120545219")
1544  def test_gradients_v1(self):
1545    x = variable_scope.get_variable(
1546        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
1547        use_resource=True)
1548    z = variable_scope.get_variable(
1549        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
1550        use_resource=True)
1551
1552    # Verify that assign op is not differentiable
1553    y = state_ops.assign(x, z**2)
1554    grads = gradients.gradients(y, z)
1555    self.assertIsNone(grads[0])
1556
1557    # Verify that when the (non differentiable) assign op is wrapped with
1558    # grad_pass_through, gradients are correctly forwarded to the inputs.
1559    # Form an input as quadratic function of variable z and check that the
1560    # gradient of output wrt to z is correct.
1561    y = custom_gradient.grad_pass_through(
1562        lambda v: state_ops.assign(x, v))(z**2)
1563    grads = gradients.gradients(y, z)
1564    with self.cached_session() as sess:
1565      sess.run(variables.global_variables_initializer())
1566      self.assertAllClose(grads[0].eval(), 6.0)
1567
1568    # Verify that variables involved in the wrapped op do not receive gradients.
1569    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
1570    grads = gradients.gradients(y, x)
1571    self.assertIsNone(grads[0])
1572
1573  @test_util.run_v2_only
1574  def test_gradients_v2(self):
1575    x = variables.Variable(1.0, name="x")
1576    z = variables.Variable(3.0, name="z")
1577
1578    # Verify that assign op is not differentiable
1579    with backprop.GradientTape() as tape:
1580      y = x.assign(z**2)
1581    grads = tape.gradient(y, z)
1582    self.assertIsNone(grads)
1583
1584    # Verify that when the (non differentiable) assign op is wrapped with
1585    # grad_pass_through, gradients are correctly forwarded to the inputs.
1586    # Form an input as quadratic function of variable z and check that the
1587    # gradient of output wrt to z is correct.
1588    with backprop.GradientTape() as tape:
1589      y = custom_gradient.grad_pass_through(x.assign)(z**2)
1590    grads = tape.gradient(y, z)
1591    self.assertAllClose(grads, 6.0)
1592
1593    # Verify that variables involved in the wrapped op do not receive gradients.
1594    with backprop.GradientTape() as tape:
1595      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
1596    grads = tape.gradient(y, x)
1597    self.assertIsNone(grads)
1598
1599
1600if __name__ == "__main__":
1601  googletest.main()
1602