1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.gradients."""
16
17import sys
18import warnings
19
20from absl.testing import parameterized
21import numpy as np
22from tensorflow.python.client import session
23from tensorflow.python.eager import backprop
24from tensorflow.python.eager import context
25from tensorflow.python.eager import function
26from tensorflow.python.framework import constant_op
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import function as framework_function
29from tensorflow.python.framework import indexed_slices
30from tensorflow.python.framework import ops
31from tensorflow.python.framework import tensor_spec
32from tensorflow.python.framework import test_ops
33from tensorflow.python.framework import test_util
34from tensorflow.python.framework.constant_op import constant
35from tensorflow.python.layers import core as core_layers
36from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import custom_gradient
41from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
42from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
43from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
44from tensorflow.python.ops import gradient_checker_v2
45from tensorflow.python.ops import gradients
46from tensorflow.python.ops import gradients_impl
47from tensorflow.python.ops import init_ops
48from tensorflow.python.ops import list_ops
49from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
50from tensorflow.python.ops import math_ops
51from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
52from tensorflow.python.ops import resource_variable_ops
53from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
54from tensorflow.python.ops import state_ops
55from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
56from tensorflow.python.ops import tensor_array_ops
57from tensorflow.python.ops import unconnected_gradients
58from tensorflow.python.ops import variable_scope
59from tensorflow.python.ops import variables
60from tensorflow.python.ops.nn_ops import bias_add
61from tensorflow.python.platform import googletest
62from tensorflow.python.util import nest
63
64
65class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
66
67  def testGradients(self):
68    with ops.Graph().as_default():
69      inp = constant(1.0, shape=[32, 100], name="in")
70      w = constant(1.0, shape=[100, 10], name="w")
71      b = constant(1.0, shape=[10], name="b")
72      xw = math_ops.matmul(inp, w, name="xw")
73      h = bias_add(xw, b, name="h")
74      w_grad = gradients.gradients(h, w)[0]
75    self.assertEqual("MatMul", w_grad.op.type)
76    self.assertEqual(w_grad.op._original_op, xw.op)
77    self.assertTrue(w_grad.op.get_attr("transpose_a"))
78    self.assertFalse(w_grad.op.get_attr("transpose_b"))
79
80  def testUnusedOutput(self):
81    with ops.Graph().as_default():
82      w = constant(1.0, shape=[2, 2])
83      x = constant(1.0, shape=[2, 2])
84      wx = math_ops.matmul(w, x)
85      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
86      c = math_ops.reduce_sum(split_wx[1])
87      gw = gradients.gradients(c, [w])[0]
88    self.assertEqual("MatMul", gw.op.type)
89
90  def testColocateGradients(self):
91    with ops.Graph().as_default() as g:
92      w = constant(1.0, shape=[1, 1])
93      x = constant(1.0, shape=[1, 2])
94      with g.device("/device:GPU:0"):
95        wx = math_ops.matmul(w, x)
96      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
97    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
98
99  def testColocateGradientsWithAggregation(self):
100    with ops.Graph().as_default() as g:
101      with g.device("/device:GPU:1"):
102        w = constant(1.0, shape=[1, 1])
103      x = constant(1.0, shape=[1, 2])
104      y = constant(1.0, shape=[1, 2])
105      wx = math_ops.matmul(w, x)
106      wy = math_ops.matmul(w, y)
107      with g.device("/device:GPU:0"):
108        z = wx + wy
109
110      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
111      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
112
113      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
114      self.assertNotEqual(wx.op.colocation_groups(), gw2.op.colocation_groups())
115
116  def testColocateGradientsWithAggregationInMultipleDevices(self):
117    with ops.Graph().as_default() as g:
118      with g.device("/device:GPU:1"):
119        w = constant(1.0, shape=[1, 1])
120      x = constant(1.0, shape=[1, 2])
121      y = constant(1.0, shape=[1, 2])
122      with g.device("/task:1"):
123        wx = math_ops.matmul(w, x)
124      with g.device("/task:2"):
125        wy = math_ops.matmul(w, y)
126      with g.device("/device:GPU:0"):
127        z = wx + wy
128
129      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
130      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
131
132      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
133      self.assertNotEqual(w.op.colocation_groups(), gw2.op.colocation_groups())
134
135  def testColocateGradientsWithGateGradients(self):
136    if not test_util.is_gpu_available():
137      self.skipTest("No GPU available")
138    with ops.Graph().as_default() as g:
139      with g.device("/device:CPU:0"):
140        x = constant(1.0, shape=[1, 1])
141        y = constant(1.0, shape=[1, 1])
142        s = x + y
143      with g.device("/device:GPU:0"):
144        z = math_ops.reduce_sum(s)
145
146      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
147                                 gate_gradients=True)[0]
148
149      # Make sure the placer doesn't complain.
150      self.evaluate(gz_x)
151
152  def testBoundaryStop(self):
153    # Test that we don't differentiate 'x'. The gradient function for 'x' is
154    # set explicitly to None so we will get an exception if the gradient code
155    # tries to differentiate 'x'.
156    with ops.Graph().as_default():
157      c = constant(1.0)
158      x = array_ops.identity(c)
159      y = x + 1.0
160      z = y + 1
161      grads = gradients.gradients(z, [x])
162      self.assertTrue(all(x is not None for x in grads))
163
164  @test_util.run_v1_only("b/120545219")
165  def testBoundaryContinue(self):
166    # Test that we differentiate both 'x' and 'y' correctly when x is a
167    # predecessor of y.
168    with self.cached_session():
169      x = constant(1.0)
170      y = x * 2.0
171      z = y * 3.0
172      grads = gradients.gradients(z, [x, y])
173      self.assertTrue(all(x is not None for x in grads))
174      self.assertEqual(6.0, grads[0].eval())
175
176  @test_util.run_v1_only("b/120545219")
177  def testAggregationMethodAccumulateN(self):
178    with self.cached_session():
179      x = constant(1.0)
180      y = x * 2.0
181      z = y + y + y + y + y + y + y + y + y + y
182      grads = gradients.gradients(
183          z, [x, y],
184          aggregation_method=gradients.AggregationMethod.
185          EXPERIMENTAL_ACCUMULATE_N)
186      self.assertTrue(all(x is not None for x in grads))
187      self.assertEqual(20.0, grads[0].eval())
188      self.assertEqual(10.0, grads[1].eval())
189
190  @test_util.run_v1_only("b/120545219")
191  def testAggregationMethodAddN(self):
192    with self.cached_session():
193      x = constant(1.0)
194      y = x * 2.0
195      z = y + y + y + y + y + y + y + y + y + y
196      grads = gradients.gradients(
197          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
198      self.assertTrue(all(x is not None for x in grads))
199      self.assertEqual(20.0, grads[0].eval())
200      self.assertEqual(10.0, grads[1].eval())
201
202  @test_util.run_v1_only("b/120545219")
203  def testAggregationMethodTree(self):
204    with self.cached_session():
205      x = constant(1.0)
206      y = x * 2.0
207      z = y + y + y + y + y + y + y + y + y + y
208      grads = gradients.gradients(
209          z, [x, y],
210          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
211      self.assertTrue(all(x is not None for x in grads))
212      self.assertEqual(20.0, grads[0].eval())
213      self.assertEqual(10.0, grads[1].eval())
214
215  def testNoGradientForStringOutputs(self):
216    with ops.Graph().as_default():
217
218      def _TestOpGrad(_, float_grad, string_grad):
219        """Gradient function for TestStringOutput."""
220        self.assertEqual(float_grad.dtype, dtypes.float32)
221        self.assertFalse(string_grad)
222        return float_grad
223
224      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
225
226      c = constant(1.0)
227      x, _ = test_ops.test_string_output(c)
228      z = x * 2.0
229      w = z * 3.0
230      grads = gradients.gradients(z, [c])
231      self.assertIsInstance(grads[0], ops.Tensor)
232      grads = gradients.gradients(w, [c])
233      self.assertIsInstance(grads[0], ops.Tensor)
234
235  def testNoGradientForStringOutputsWithOpNamespace(self):
236    with ops.Graph().as_default():
237
238      def _TestOpGrad(_, float_grad, string_grad):
239        """Gradient function for TestStringOutput."""
240        self.assertEqual(float_grad.dtype, dtypes.float32)
241        self.assertFalse(string_grad)
242        return float_grad
243
244      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)
245
246      c = constant(1.0)
247      x, _ = test_ops.namespace_test_string_output(c)
248      z = x * 2.0
249      w = z * 3.0
250      grads = gradients.gradients(z, [c])
251      self.assertIsInstance(grads[0], ops.Tensor)
252      grads = gradients.gradients(w, [c])
253      self.assertIsInstance(grads[0], ops.Tensor)
254
255  def testSingletonIndexedSlices(self):
256    with ops.Graph().as_default():
257      x = array_ops.placeholder(dtypes.float32)
258      y = array_ops.identity(x)
259      dy = indexed_slices.IndexedSlices(
260          array_ops.placeholder(dtypes.float32),
261          array_ops.placeholder(dtypes.int32))
262      dx, = gradients.gradients(y, x, grad_ys=dy)
263      # The IndexedSlices gradient of tf.identity is the identity map.
264      with self.cached_session() as sess:
265        vdx, vdy = sess.run(
266            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
267      self.assertEqual(vdx, vdy)
268
269  @test_util.run_v1_only("b/120545219")
270  def testNonDifferentiableSwitchInWhileLoop(self):
271    with ops.Graph().as_default():
272      v = array_ops.placeholder(dtypes.float32, [])
273
274      def _Step(i, a, ta):
275        a += math_ops.cast(v, dtypes.int32)
276        return (i + 1, a, ta.write(i, a))
277
278      n = 4
279      i, _, ta = control_flow_ops.while_loop(
280          lambda i, *_: i < n,
281          _Step, [0, 0, tensor_array_ops.TensorArray(
282              dtypes.int32, size=n)])
283      target = ta.read(i - 1)
284      grad, = gradients.gradients(target, v)
285      self.assertIsNone(grad)
286
287  def testVariableReadValueGradient(self):
288    with ops.Graph().as_default():
289      init = constant_op.constant(100.0)
290      var = variables.Variable(init)
291      gradient = gradients.gradients(var.read_value(), var)
292      self.assertIsNotNone(gradient)
293
294  @parameterized.parameters(dtypes.float32, dtypes.float64)
295  def testVariableDefaultGrad(self, dtype):
296    with ops.Graph().as_default():
297      init = constant_op.constant(100.0, dtype=dtype)
298      var = variables.Variable(init)
299      dummy_const = constant_op.constant(0.0)
300      gradient = gradients.gradients(
301          dummy_const,
302          var,
303          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
304      )[0]
305      self.assertEqual(gradient.dtype, dtype)
306      self.assertIsNotNone(gradient)
307
308  def testVariableAsGraphElementGradient(self):
309    with ops.Graph().as_default() as graph:
310      init = constant_op.constant(100.0)
311      var = variables.Variable(init)
312      gradient = gradients.gradients(graph.as_graph_element(var), var)
313      self.assertIsNotNone(gradient)
314
315  @test_util.run_v1_only("b/120545219")
316  def testVariableRefGradient(self):
317    with ops.Graph().as_default():
318      init = constant_op.constant(100.0)
319      var = variables.VariableV1(init)
320      gradient = gradients.gradients(var._ref(), var)
321      self.assertIsNotNone(gradient)
322
323  @test_util.run_v1_only("b/120545219")
324  def testDependentYs(self):
325    with self.cached_session():
326      x = constant_op.constant(3.0)
327      y = math_ops.square(x)
328      y1 = math_ops.square(y)
329      y2 = math_ops.square(y1)
330      g = gradients.gradients([y, y2], x)
331      self.assertAllClose(17502.0, g[0])
332      g = gradients.gradients(y + y2, x)
333      self.assertAllClose(17502.0, g[0])
334      z = array_ops.identity(y)
335      z2 = array_ops.identity(y2)
336      g = gradients.gradients([z, z2], x)
337      self.assertAllClose(17502.0, g[0])
338
339  @test_util.run_v1_only("b/120545219")
340  def testPartialDerivatives(self):
341    with self.cached_session():
342      x = constant_op.constant(1.)
343      y = 2 * x
344      z = x + y
345      totalg = gradients.gradients(z, [x, y])
346      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
347      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
348      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
349
350  @test_util.run_v1_only("b/120545219")
351  def testStopGradients(self):
352    def _MakeGraph(rng, stop_gradients=()):
353      def _FunctionOf(xs, k=3):
354        return ops.convert_to_tensor(
355            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
356            + rng.rand(k, k))
357
358      a = _FunctionOf([])
359      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
360      b = _FunctionOf([a])
361      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
362      c = _FunctionOf([a, b])
363      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
364      d = _FunctionOf([b, c])
365      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
366      return dict(a=a, b=b, c=c, d=d)
367
368    def _Gradients(ys, xs, **kwargs):
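      # None entries from gradients.gradients (returned when an input is fully
      # stopped or unconnected) are replaced with explicit zeros so that grad1
      # and grad2 can be compared elementwise below.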
369      dydxs = gradients.gradients(ys, xs, **kwargs)
370      dydxs = [0. * x if dydx is None else dydx
371               for x, dydx in zip(xs, dydxs)]
372      return dydxs
373
374    seed = np.random.randint(1000)
375    cases = []
376    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
377    graph = _MakeGraph(np.random.RandomState(seed))
378    for constants in subsets:
379      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
380      for variables_ in subsets:
381        # compute the gradient when stopped using tf.stop_gradient ops
382        grad1 = _Gradients([graph_with_stops["d"]],
383                           [graph_with_stops[v] for v in variables_])
384        # compute the gradient when stopped using the stop_gradients kwarg
385        grad2 = _Gradients([graph["d"]],
386                           [graph[v] for v in variables_],
387                           stop_gradients=[graph[v] for v in constants])
388        cases.append(dict(grad1=grad1, grad2=grad2,
389                          constants=constants, variables=variables_))
390
391    # evaluate all tensors in one call to session.run for speed
392    with self.cached_session() as sess:
393      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
394
395    for (npgrad1, npgrad2), case in zip(results, cases):
396      for a, b in zip(npgrad1, npgrad2):
397        np.testing.assert_allclose(a, b)
398
399  def testUnconnectedGradientsNoneUnconnectedGradients(self):
400    with ops.Graph().as_default():
401      x = constant(1.0, shape=[2, 2])
402      y = constant(3.0, shape=[3, 1])
403      grad = gradients.gradients(
404          [y], [x], unconnected_gradients="none")
405    self.assertIsNone(grad[0])
406
407  def testUnconnectedGradientsZerosUnconnectedGradients(self):
408    with ops.Graph().as_default():
409      x = constant(1.0, shape=[2, 2])
410      y = constant(3.0, shape=[3, 1])
411      grads = gradients.gradients(
412          [y], [x], unconnected_gradients="zero")
413
414      self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])
415
416  def testUnconnectedGradientsZeroConnectedGradients(self):
417    with ops.Graph().as_default():
418      x = constant(1.0)
419      y = x * 3.0
420      grad = gradients.gradients(
421          [y], [x], unconnected_gradients="zero")
422
423      self.assertEqual(3.0, self.evaluate(grad)[0])
424
425  def testUnknownUnconnectedGradientsValueGiven(self):
426    with ops.Graph().as_default():
427      x = constant(1.0)
428      y = constant(1.0)
429      with self.assertRaisesRegex(
430          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
431        gradients.gradients([y], [x], unconnected_gradients="nonsense")
432
433  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
434                            unconnected_gradients.UnconnectedGradients.NONE)
435  def testUnconnectedOpWithMultipleOutputs(self, unconnected_gradients_val):
436    with ops.Graph().as_default():
437      #  a    b
438      #  |    |
439      # IdentityN
440      #  |    |
441      #  c    d
442      #  |
443      # Identity
444      #  |
445      #  e
446      a = constant_op.constant(1.0)
447      b = constant_op.constant(1.0)
448      c, d = array_ops.identity_n([a, b])
449      e = array_ops.identity(c)
450      # The aggregated grads for the IdentityN node would look like
451      # [Tensor, None]. We expect this None to be converted to zeros.
452      output = gradients.gradients(
453          e, d, unconnected_gradients=unconnected_gradients_val)
454      if (unconnected_gradients_val ==
455          unconnected_gradients.UnconnectedGradients.ZERO):
456        self.assertIsNotNone(output[0])
457      else:
458        self.assertIsNone(output[0])
459
460  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
461                            unconnected_gradients.UnconnectedGradients.NONE)
462  def testUnconnectedOpWithMultipleOutputsStopGradient(
463      self, unconnected_gradients_val):
464    with ops.Graph().as_default():
465      #  a    b
466      #  |    |
467      # IdentityN
468      #  |    |
469      #  c    d
470      #  |    |
471      #  SG   |
472      #  |    |
473      #   \  /
474      #    +
475      #    e
476      a = constant_op.constant(1.0)
477      b = constant_op.constant(1.0)
478      c, d = array_ops.identity_n([a, b])
479      e = array_ops.stop_gradient(c) + d
480      # The aggregated grads for the IdentityN node would look like
481      # [None, Tensor]. We expect this None to be converted to zeros.
482      output = gradients.gradients(
483          e, c, unconnected_gradients=unconnected_gradients_val)
484      if (unconnected_gradients_val ==
485          unconnected_gradients.UnconnectedGradients.ZERO):
486        self.assertIsNotNone(output[0])
487      else:
488        self.assertIsNone(output[0])
489
490
491class FunctionGradientsTest(test_util.TensorFlowTestCase):
492
493  @classmethod
494  def XSquarePlusB(cls, x, b):
495    return x * x + b
496
497  @classmethod
498  def XSquarePlusBGradient(cls, x, b, g):
499    # Perturb gradients (multiply by 2), so we can test that this was called.
500    g *= 2.0
501    return g * 2.0 * x, g
502
503  @classmethod
504  def _PythonGradient(cls, op, grad):
505    # Perturb gradients (multiply by 3), so we can test that this was called.
506    grad *= 3.0
507    return grad * op.inputs[0] * 2.0, grad
508
509  @classmethod
510  def _GetFunc(cls, **kwargs):
511    return framework_function.Defun(dtypes.float32, dtypes.float32, **
512                                    kwargs)(cls.XSquarePlusB)
513
514  def _GetFuncGradients(self, f, x_value, b_value):
515    x = constant_op.constant(x_value, name="x")
516    b = constant_op.constant(b_value, name="b")
517
518    y = f(x, b)
519    grads = gradients.gradients(y, [x, b])
520
521    return self.evaluate(grads)
522
523  def testFunctionGradientsBasic(self):
524    g = ops.Graph()
525    with g.as_default():
526      f = self._GetFunc()
527      # Get gradients (should add SymbolicGradient node for function).
528      grads = self._GetFuncGradients(f, [2.0], [1.0])
529      self.assertAllEqual([4.0], grads[0])
530      self.assertAllEqual([1.0], grads[1])
531
532  def testFunctionGradientsComposition(self):
533    with ops.Graph().as_default():
534      f = self._GetFunc()
535      x = constant_op.constant([2.0], name="x")
536      b1 = constant_op.constant([1.0], name="b1")
537      b2 = constant_op.constant([1.0], name="b2")
538
539      y = f(f(x, b1), b2)
540      # Build gradient graph (should add SymbolicGradient node for function).
541      grads = gradients.gradients(y, [x, b1])
542
543      self.assertAllEqual([40.0], self.evaluate(grads)[0])
544      self.assertAllEqual([10.0], self.evaluate(grads)[1])
545
546  def testFunctionGradientsWithGradFunc(self):
547    g = ops.Graph()
548    with g.as_default():
549      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
550                                           dtypes.float32)(
551                                               self.XSquarePlusBGradient)
552      f = self._GetFunc(grad_func=grad_func)
553      # Get gradients (should add SymbolicGradient node for function, which
554      # uses the grad_func above, which multiplies all gradients by 2).
555      grads = self._GetFuncGradients(f, [2.0], [1.0])
556      self.assertAllEqual([4.0 * 2], grads[0])
557      self.assertAllEqual([1.0 * 2], grads[1])
558
559  def testFunctionGradientWithRegistration(self):
560    g = ops.Graph()
561    with g.as_default():
562      f = self._GetFunc(python_grad_func=self._PythonGradient)
563      # Get gradients, using the python gradient function. It multiplies the
564      # gradients by 3.
565      grads = self._GetFuncGradients(f, [2.0], [1.0])
566      self.assertAllEqual([4.0 * 3], grads[0])
567      self.assertAllEqual([1.0 * 3], grads[1])
568
569  def testFunctionGradientWithGradFuncAndRegistration(self):
570    g = ops.Graph()
571    with g.as_default():
572      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
573                                           dtypes.float32)(
574                                               self.XSquarePlusBGradient)
575      with self.assertRaisesRegex(ValueError, "Gradient defined twice"):
576        f = self._GetFunc(
577            grad_func=grad_func, python_grad_func=self._PythonGradient)
578        f.add_to_graph(ops.Graph())
579
580  def testGradientWrtCaptured(self):
581    with ops.Graph().as_default():
582      x = constant_op.constant(1.0, name="x")
583
584      @function.defun()
585      def Foo():
586        y = math_ops.multiply(x, 2.0, name="y")
587        g = gradients_impl.gradients(y, x)
588        return g[0]
589
590      f = Foo()
591
592      self.assertEqual(self.evaluate(f), 2.0)
593
594  def testGradientOfCaptured(self):
595    with ops.Graph().as_default():
596      x = constant_op.constant(1.0, name="x")
597      y = math_ops.multiply(x, 2.0, name="y")
598
599      @framework_function.Defun()
600      def Foo():
601        g = gradients_impl.gradients(y, x)
602        return g[0]
603
604      f = Foo()
605
606      self.assertEqual(self.evaluate(f), 2.0)
607
608  def testCapturedResourceVariable(self):
609    with ops.Graph().as_default():
610      var = resource_variable_ops.ResourceVariable(1.0, name="var")
611
612      @function.defun()
613      def Foo():
614        y = math_ops.multiply(var, 2.0, name="y")
615        g = gradients_impl.gradients(y, var)
616        return g[0]
617
618      f = Foo()
619
620      self.evaluate(variables.global_variables_initializer())
621      self.assertEqual(self.evaluate(f), 2.0)
622
623  def testCapturedNested(self):
624    with ops.Graph().as_default():
625      x1 = constant_op.constant(1.0, name="x1")
626      x2 = constant_op.constant(2.0, name="x2")
627      x3 = math_ops.multiply(x1, x2, name="x3")
628
629      @function.defun()
630      def Outer():
631        outer1 = array_ops.identity(x1, name="outer1")
632
633        @function.defun()
634        def Inner():
635          inner1 = array_ops.identity(outer1, name="inner1")
636          inner2 = array_ops.identity(x2, name="inner2")
637          inner3 = array_ops.identity(x3, name="inner3")
638          return gradients_impl.gradients([inner1, inner2, inner3, x1],
639                                          [x1, x2])
640
641        return Inner()
642
643      x1_grad, x2_grad = Outer()
644
645      # 1.0 + None + 2.0 + 1.0 = 4.0
646      self.assertEqual(self.evaluate(x1_grad), 4.0)
647      # None + 1.0 + 1.0 + None = 2.0
648      self.assertEqual(self.evaluate(x2_grad), 2.0)
649
650  def testCapturedFromFunction(self):
651    with ops.Graph().as_default():
652      x = constant_op.constant(1.0, name="x")
653
654      @function.defun()
655      def Outer():
656        y = math_ops.multiply(x, 2.0, name="y")
657
658        @function.defun()
659        def Inner():
660          z = math_ops.multiply(y, 3.0, name="z")
661          g = gradients_impl.gradients(z, y)
662          return g[0]
663
664        return Inner()
665
666      z_grad = Outer()
667
668      self.assertEqual(self.evaluate(z_grad), 3.0)
669
670  def testCapturedEagerTensors(self):
671    # Test that we can handle captured eager tensors unrelated to the gradient
672    # computation (i.e. we need to ignore them).
673    # TODO(skyewm): make it an error if you try to take the gradient wrt a
674    # captured EagerTensor
675    with context.eager_mode():
676      c = constant_op.constant(2.0, name="c")
677
678      @function.defun
679      def Foo():
680        x = constant_op.constant(10.0, name="x")
681        y = math_ops.multiply(x, c, name="y")
682        # Regression test for b/122564611.
683        z = math_ops.multiply(c, y, name="z")
684        g = gradients_impl.gradients(z, x)
685        return g[0]
686
687      self.assertEqual(Foo().numpy(), 4.0)
688
689
690class StopGradientTest(test_util.TensorFlowTestCase):
691
692  def testStopGradient(self):
693    with ops.Graph().as_default():
694      inp = constant(1.0, shape=[100, 32], name="in")
695      out = array_ops.stop_gradient(inp)
696      igrad = gradients.gradients(out, inp)[0]
697    assert igrad is None
698
699
700class PreventGradientTest(test_util.TensorFlowTestCase):
701
702  def testPreventGradient(self):
703    with ops.Graph().as_default():
704      inp = constant(1.0, shape=[100, 32], name="in")
705      out = array_ops.prevent_gradient(inp)
706      with self.assertRaisesRegex(LookupError, "explicitly disabled"):
707        _ = gradients.gradients(out, inp)
708
709
710class HessianVectorProductTest(test_util.TensorFlowTestCase):
711
712  @test_util.run_v1_only("b/120545219")
713  def testHessianVectorProduct(self):
714    # Manually compute the Hessian explicitly for a low-dimensional problem
715    # and check that HessianVectorProduct matches multiplication by the
716    # explicit Hessian.
717    # Specifically, the Hessian of f(x) = x^T A x is
718    # H = A + A^T.
719    # We expect HessianVectorProduct(f(x), x, v) to be H v.
720    m = 4
721    rng = np.random.RandomState([1, 2, 3])
722    mat_value = rng.randn(m, m).astype("float32")
723    v_value = rng.randn(m, 1).astype("float32")
724    x_value = rng.randn(m, 1).astype("float32")
725    hess_value = mat_value + mat_value.T
726    hess_v_value = np.dot(hess_value, v_value)
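    # Expository sanity cross-check (an addition, not part of the original
    # test): the same Hessian-vector product computed term-by-term in NumPy,
    # H v = A v + A^T v.
    assert np.allclose(
        hess_v_value,
        np.dot(mat_value, v_value) + np.dot(mat_value.T, v_value))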
727    for use_gpu in [False, True]:
728      with self.cached_session(use_gpu=use_gpu):
729        mat = constant_op.constant(mat_value)
730        v = constant_op.constant(v_value)
731        x = constant_op.constant(x_value)
732        mat_x = math_ops.matmul(mat, x, name="Ax")
733        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
734        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
735        hess_v_actual = self.evaluate(hess_v)
736      self.assertAllClose(hess_v_value, hess_v_actual)
737
738
739class HessianTest(test_util.TensorFlowTestCase):
740
741  @test_util.run_v1_only("b/120545219")
742  def testHessian1D(self):
743    # Manually compute the Hessian explicitly for a low-dimensional problem
744    # and check that `hessian` matches. Specifically, the Hessian of
745    # f(x) = x^T A x is H = A + A^T.
746    m = 4
747    rng = np.random.RandomState([1, 2, 3])
748    mat_value = rng.randn(m, m).astype("float32")
749    x_value = rng.randn(m).astype("float32")
750    hess_value = mat_value + mat_value.T
751    with self.session():
752      mat = constant_op.constant(mat_value)
753      x = constant_op.constant(x_value)
754      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
755      hess = gradients.hessians(x_mat_x, x)[0]
756      hess_actual = self.evaluate(hess)
757    self.assertAllClose(hess_value, hess_actual)
758
759  @test_util.run_v1_only("b/120545219")
760  def testHessian1D_multi(self):
761    # Test the computation of the hessian with respect to multiple tensors
762    m = 4
763    n = 3
764    rng = np.random.RandomState([1, 2, 3])
765    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
766    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
767    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
768    with self.session():
769      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
770      xs = [constant_op.constant(x_value) for x_value in x_values]
771      xs_mats_xs = [
772          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
773          for x, mat in zip(xs, mats)
774      ]
775      hessians = gradients.hessians(xs_mats_xs, xs)
776      hessians_actual = [hess.eval() for hess in hessians]
777    for hess_value, hess_actual in zip(hess_values, hessians_actual):
778      self.assertAllClose(hess_value, hess_actual)
779
780  @test_util.run_v1_only("b/120545219")
781  def testHessianInvalidDimension(self):
782    for shape in [(10, 10), None]:
783      with self.cached_session():
784        x = array_ops.placeholder(dtypes.float32, shape)
785        # Expect a ValueError because the dimensions are wrong
786        with self.assertRaises(ValueError):
787          gradients.hessians(x, x)
788
789  @test_util.run_v1_only("b/120545219")
790  def testHessian2D_square_matrix(self):
791    # Manually compute the Hessian explicitly for a low-dimensional problem
792    # and check that `hessian` matches. Specifically, the Hessian of
793    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
794    m = 3
795    rng = np.random.RandomState([1, 2, 3])
796    x_value = rng.randn(m, m).astype("float32")
797    with self.session():
798      x = constant_op.constant(x_value)
799      x_square = math_ops.reduce_sum(
800          math_ops.matmul(array_ops.transpose(x), x) * 0.5
801      )
802      hess = gradients.hessians(x_square, x)[0]
803      hess_actual = self.evaluate(hess)
804    hess_value = np.bmat([
805        [elem*np.ones((m, m)) for elem in vec]
806        for vec in np.eye(m)
807    ]).astype("float32")
808    self.assertAllEqual((m, m, m, m), hess_actual.shape)
809    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))
810
811  @test_util.run_v1_only("b/120545219")
812  def testHessian2D_non_square_matrix(self):
813    m = 3
814    n = 4
815    rng = np.random.RandomState([1, 2, 3])
816    x_value = rng.randn(m, n).astype("float32")
817    with self.session():
818      x = constant_op.constant(x_value)
819      x_square = math_ops.reduce_sum(
820          math_ops.matmul(array_ops.transpose(x), x) * 0.5
821      )
822      hess = gradients.hessians(x_square, x)[0]
823      hess_actual = self.evaluate(hess)
824    hess_value = np.bmat([
825        [elem*np.ones((n, n)) for elem in vec]
826        for vec in np.eye(m)
827    ]).astype("float32")
828    self.assertAllEqual((m, n, m, n), hess_actual.shape)
829    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))
830
831
832class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
833
834  @test_util.run_v1_only("b/120545219")
835  def testIndexedSlicesToTensor(self):
836    with self.cached_session():
837      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
838      c = constant_op.constant(np_val)
839      c_sparse = math_ops._as_indexed_slices(c)
840      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
841      c_dense = math_ops.multiply(c_sparse, 1.0)
842      self.assertAllClose(np_val, self.evaluate(c_dense))
843
844  @test_util.run_v1_only("b/120545219")
845  def testIndexedSlicesToTensorList(self):
846    with self.cached_session():
847      numpy_list = []
848      dense_list = []
849      sparse_list = []
850      for _ in range(3):
851        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
852        c = constant_op.constant(np_val)
853        c_sparse = math_ops._as_indexed_slices(c)
854        numpy_list.append(np_val)
855        dense_list.append(c)
856        sparse_list.append(c_sparse)
857      packed_dense = array_ops.stack(dense_list)
858      packed_sparse = array_ops.stack(sparse_list)
859      self.assertAllClose(packed_dense, self.evaluate(packed_sparse))
860
861  @test_util.run_v1_only("b/120545219")
862  def testInt64Indices(self):
863    with self.cached_session():
864      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
865      c = constant_op.constant(np_val)
866      c_sparse = math_ops._as_indexed_slices(c)
867      c_sparse = indexed_slices.IndexedSlices(
868          c_sparse.values,
869          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
870      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
871      c_dense = math_ops.multiply(c_sparse, 1.0)
872      self.assertAllClose(np_val, self.evaluate(c_dense))
873
874  @test_util.run_v1_only("b/120545219")
875  def testWarnings(self):
876    # TODO(gunan) Reenable after this issue is fixed:
877    # https://github.com/google/protobuf/issues/2812
878    if sys.version_info >= (3, 5):
879      self.skipTest("Skipped test for Python 3.5+")
880
881    # Smaller than the threshold: no warning.
882    c_sparse = indexed_slices.IndexedSlices(
883        array_ops.placeholder(dtypes.float32),
884        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
885    with warnings.catch_warnings(record=True) as w:
886      math_ops.multiply(c_sparse, 1.0)
887    self.assertEqual(0, len(w))
888
889    # Greater than or equal to the threshold: warning.
890    c_sparse = indexed_slices.IndexedSlices(
891        array_ops.placeholder(dtypes.float32),
892        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
893    # "always" filter prevents the warning from being suppressed if it was
894    # already triggered in a different test.
895    warnings.simplefilter("always")
896    with warnings.catch_warnings(record=True) as w:
897      math_ops.multiply(c_sparse, 1.0)
898    self.assertEqual(1, len(w))
899    self.assertIn(
900        "with 100000000 elements. This may consume a large amount of memory.",
901        str(w[0].message))
902
903    # Unknown dense shape: warning.
904    c_sparse = indexed_slices.IndexedSlices(
905        array_ops.placeholder(dtypes.float32),
906        array_ops.placeholder(dtypes.int32),
907        array_ops.placeholder(dtypes.int32))
908    with warnings.catch_warnings(record=True) as w:
909      math_ops.multiply(c_sparse, 1.0)
910    self.assertEqual(1, len(w))
911    self.assertIn(
912        "of unknown shape. This may consume a large amount of memory.",
913        str(w[0].message))
914
915
916class OnlyRealGradientsTest(test_util.TensorFlowTestCase):
917
918  @test_util.run_v1_only("b/120545219")
919  def testRealOnly(self):
920    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
921    y = math_ops.square(x)
922    with self.assertRaisesRegex(
923        TypeError, r"Gradients of complex tensors .* must set grad_ys "
924        r"\(y\.dtype = complex64\)"):
925      gradients.gradients(y, x)
926
927
928class ResourceCondTest(test_util.TensorFlowTestCase):
929
930  @test_util.run_v1_only("b/120545219")
931  def testBasic(self):
932    gamma = resource_variable_ops.ResourceVariable(
933        np.random.random((3,)),
934        dtype="float32", name="gamma")
935
936    inputs = array_ops.ones(shape=(3,), dtype="float32")
937
938    def TestFn():
939      output = inputs + gamma
940      return output
941
942    training = array_ops.placeholder_with_default(True, shape=())
943    output = control_flow_ops.cond(
944        training, TestFn, lambda: inputs)
945
946    loss = output
947
948    grads = gradients.gradients(
949        loss, [gamma])
950    self.assertNotIn(None, grads)
951
952
953class GetDependentVariablesTest(test_util.TensorFlowTestCase):
954
955  def testNoVariables(self):
956    with ops.Graph().as_default():
957      func = lambda x: array_ops.identity(x) + 5.0
958      input_t = constant_op.constant(2.0)
959      result_t = func(input_t)
960      dependent_vars = custom_gradient._get_dependent_variables(
961          [input_t], [result_t])
962
963      # There are no variables.
964      self.assertEqual(dependent_vars, [])
965
966  def testVariablesOutside(self):
967    with ops.Graph().as_default():
968      init = constant_op.constant(100.0)
969      var = variables.Variable(init)
970
971      # The variable is closed over. It should be found.
972      func = lambda x: array_ops.identity(x) + 5.0 + var
973
974      input_t = constant_op.constant(2.0)
975      result_t = func(input_t)
976      dependent_vars = custom_gradient._get_dependent_variables(
977          [input_t], [result_t])
978      self.assertEqual(dependent_vars, [var])
979
980  def testVariableSamePrefix(self):
981    with ops.Graph().as_default():
982      var_name = "my_variable"
983      v_z = variable_scope.get_variable(var_name, shape=())
984      v_o = variable_scope.get_variable(var_name + "_ones", shape=())
985
986      # The variable is closed over. It should be found.
987      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o
988
989      input_t = constant_op.constant(2.0)
990      result_t = func(input_t)
991      dependent_vars = custom_gradient._get_dependent_variables(
992          [input_t], [result_t])
993      self.assertEqual(set(dependent_vars), set([v_o, v_z]))
994
995  def testVariablesOutsideButDSeparated(self):
996    with ops.Graph().as_default():
997      init = constant_op.constant(100.0)
998      var = variables.Variable(init)
999
1000      # The variable is d-separated by the inputs. It should not be found.
1001      input_t = array_ops.identity(var) * 5.0
1002
1003      func = lambda x: array_ops.identity(x) + 5.0
1004      result_t = func(input_t)
1005      dependent_vars = custom_gradient._get_dependent_variables(
1006          [input_t], [result_t])
1007      self.assertEqual(dependent_vars, [])
1008
1009  def testVariablesOutsideAndNonDifferentiable(self):
1010    with ops.Graph().as_default():
1011      init = constant_op.constant(100.0, shape=(5,))
1012      var = variables.Variable(init, shape=(5,))
1013
1014      def _Func(x):
1015        # Non-differentiable dependency on var.
1016        # The variable should not be found.
1017        y = array_ops.ones_like(var)
1018        return array_ops.identity(x) + 5.0 + y
1019
1020      input_t = constant_op.constant(2.0)
1021      result_t = _Func(input_t)
1022      dependent_vars = custom_gradient._get_dependent_variables(
1023          [input_t], [result_t])
1024      self.assertEqual(dependent_vars, [])
1025
1026  def testGetVariableByName(self):
1027    with context.graph_mode():
1028      init = constant_op.constant(100.0)
1029      var = variable_scope.variable(init, name="a/replica_1")
1030      if isinstance(var, variables.RefVariable):
1031        var._variable = array_ops.identity(var, name="a")
1032      else:
1033        var._handle = array_ops.identity(var, name="a")
1034      var2 = custom_gradient.get_variable_by_name("a")
1035      self.assertEqual(var2.name, var.name)
1036
1037  def testVariablesOutsideAndNonTrainable(self):
1038    with ops.Graph().as_default():
1039      init = constant_op.constant(100.0, shape=(5,))
1040
1041      # Both variables are used in the function but only the trainable one
1042      # should be found.
1043      var_trainable = variables.Variable(init, shape=(5,))
1044      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)
1045
1046      def _Func(x):
1047        del x
1048        return var_trainable + var_nontrainable
1049
1050      input_t = constant_op.constant(2.0)
1051      result_t = _Func(input_t)
1052      dependent_vars = custom_gradient._get_dependent_variables(
1053          [input_t], [result_t])
1054      self.assertEqual(dependent_vars, [var_trainable])
1055
1056  def testVariablesOutsideAndCustomGradient(self):
1057    with ops.Graph().as_default():
1058      init = constant_op.constant(100.0, shape=(5,))
1059      var = variables.Variable(init, shape=(5,))
1060
1061      @custom_gradient.custom_gradient
1062      def _MyOnesLike(x):
1063        """Dummy version of ones_like which defines a gradient."""
1064
1065        output = array_ops.ones_like(x)
1066
1067        def _Grad(dy):
1068          return array_ops.identity(dy)
1069
1070        return output, _Grad
1071
1072      def _Func(x):
1073        # Non-differentiable operation wrapped with a custom gradient.
1074        # The variable should be found.
1075        y = _MyOnesLike(var)
1076        return array_ops.identity(x) + 5.0 + y
1077
1078      input_t = constant_op.constant(2.0)
1079      result_t = _Func(input_t)
1080      dependent_vars = custom_gradient._get_dependent_variables(
1081          [input_t], [result_t])
1082      self.assertEqual(dependent_vars, [var])
1083
1084
1085class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
1086
1087  def testCustomGradientTrivial(self):
1088
1089    @custom_gradient.custom_gradient
1090    def MyIdentity(x):
1091
1092      def Grad(dy):
1093        return [3 * dy]
1094
1095      return x, Grad
1096
1097    with ops.Graph().as_default():
1098      x = constant(3.)
1099      y = MyIdentity(MyIdentity(x))
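      # Each MyIdentity scales the incoming gradient by 3, so the gradient
      # chained through two applications is 3 * 3 = 9.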
1100      dy = gradients.gradients(y, x)[0]
1101      with session.Session():
1102        self.assertEqual(9., self.evaluate(dy))
1103
1104  def testCustomGradient(self):
1105
1106    @custom_gradient.custom_gradient
1107    def MyMultiply(x1, x2):
1108      result = x1 * x2
1109
1110      def Grad(dy):
1111        # Switched the ordering here.
1112        return [dy * x1, dy * x2]
1113
1114      return result, Grad
1115
1116    with ops.Graph().as_default():
1117      x1 = constant(3.)
1118      x2 = constant(5.)
1119      y = MyMultiply(x1, x2)
1120      dy = gradients.gradients(y, [x1, x2])
1121
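      # Grad deliberately swaps the ordering, so the returned gradients are
      # [x1, x2] = [3., 5.] rather than the true [x2, x1] = [5., 3.], which
      # confirms that the custom gradient function was used.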
1122      self.assertAllEqual([3., 5.], self.evaluate(dy))
1123
1124  def testCustomGradientClass(self):
1125
1126    class Model:
1127
1128      @custom_gradient.custom_gradient
1129      def Multiply(self, x1, x2):
1130        result = x1 * x2
1131        grad = lambda dy: (dy * x1, dy * x2)
1132        return result, grad
1133
1134    with ops.Graph().as_default():
1135      x1 = constant(3.)
1136      x2 = constant(5.)
1137      m = Model()
1138      y = m.Multiply(x1, x2)
1139      dy = gradients.gradients(y, [x1, x2])
1140      self.assertAllEqual([3., 5.], self.evaluate(dy))
1141
1142  def testCustomGradientErrors(self):
1143
1144    @custom_gradient.custom_gradient
1145    def F(x):
1146
1147      def Grad(_):
1148        raise RuntimeError("x")
1149
1150      return x, Grad
1151
1152    with ops.Graph().as_default():
1153      x = constant(1.0)
1154      y = F(x)
1155      with self.assertRaises(RuntimeError):
1156        gradients.gradients(y, x)
1157
1158  def testCustomGradientWithVariables(self):
1159
1160    @custom_gradient.custom_gradient
1161    def F(x):
1162      out = core_layers.dense(x, 3, use_bias=False)
1163
1164      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
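        # When a custom_gradient function creates or captures trainable
        # variables, TensorFlow invokes the gradient function with a
        # `variables` keyword argument and expects back a pair: gradients with
        # respect to the inputs and a list of gradients for those variables.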
1165        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1166        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1167        return grads[0], [array_ops.ones((4, 3))]
1168
1169      return out, Grad
1170
1171    with ops.Graph().as_default():
1172      x = array_ops.ones((2, 4))
1173      with variable_scope.variable_scope("f", use_resource=True) as vs:
1174        y = F(x)
1175        all_vars = vs.global_variables()
1176        assert len(all_vars) == 1
1177      grads = gradients.gradients(y, [x, all_vars[0]])
1178      for g in grads:
1179        self.assertIsNotNone(g)
1180
1181      self.evaluate(variables.global_variables_initializer())
1182      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
1183      self.assertEqual(12., dw)
1184
1185  def testCustomGradientWithCapture(self):
1186    with ops.Graph().as_default():
1187      x = constant(3.)
1188
1189      @framework_function.Defun(dtypes.float32)
1190      def F(y):
1191
1192        @custom_gradient.custom_gradient
1193        def MyMultiply(x1, x2):
1194          result = x1 * x2
1195
1196          def Grad(dy):
1197            # Switched the ordering here.
1198            return [dy * x1, dy * x2]
1199
1200          return result, Grad
1201
1202        res = MyMultiply(x, y)
1203        return gradients.gradients(res, [y])
1204
1205      y = constant(5.)
1206      dy = F(y)
1207      self.assertAllEqual(5., self.evaluate(dy))
1208
1209  def testCustomGradientWithVariablesNoFalsePositives(self):
1210
1211    @custom_gradient.custom_gradient
1212    def F(x):
1213      out = core_layers.dense(x, 3, use_bias=False)
1214
1215      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1216        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1217        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1218        return grads[0], [array_ops.ones((3, 3))]
1219
1220      return out, Grad
1221
1222    with ops.Graph().as_default():
1223      with variable_scope.variable_scope("f", use_resource=True) as vs:
1224        a = array_ops.ones((2, 4))
1225
1226        # Variables in these layers shouldn't be picked up by the decorator.
1227        b = core_layers.dense(a, 3, use_bias=False)
1228        c = core_layers.dense(b, 3, use_bias=False)
1229        x = core_layers.dense(b, 3, use_bias=False) + c
1230
1231        # Only the variables used in F.
1232        y = F(x)
1233
1234        all_vars = vs.global_variables()
1235        assert len(all_vars) == 4
1236      grads = gradients.gradients(y, [x] + all_vars)
1237      _, var_grads = grads[0], grads[1:]
1238      for g in grads:
1239        self.assertIsNotNone(g)
1240
1241      self.evaluate(variables.global_variables_initializer())
1242      dw = self.evaluate(math_ops.reduce_sum(var_grads[-1]))
1243      self.assertEqual(9., dw)
1244
1245  def testCustomGradientWithVariablesEager(self):
1246    with context.eager_mode():
1247      layer = core_layers.Dense(4, use_bias=False)
1248
1249      @custom_gradient.custom_gradient
1250      def F(x):
1251        out = layer(x)
1252
1253        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1254          del out_grad
1255          self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1256          return (array_ops.ones((3, 2)),
1257                  [array_ops.ones((2, 4))])
1258
1259        return out, Grad
1260
1261      x = array_ops.ones((3, 2)) + 2.
1262      with backprop.GradientTape() as tape:
1263        tape.watch(x)
1264        y = F(x)
1265      w, = layer.variables
1266      dx, dw = tape.gradient(y, [x, w])
1267      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
1268      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())
1269
1270  @test_util.run_v1_only("b/120545219")
1271  def testCustomGradientErrorsWithNonResourceVariables(self):
1272
1273    def F(x, use_resource=False):
1274      with variable_scope.variable_scope("f", use_resource=use_resource):
1275        out = core_layers.dense(x, 4, use_bias=False)
1276
1277      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1278        del out_grad
1279        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1280        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
1281
1282      return out, Grad
1283
1284    @custom_gradient.custom_gradient
1285    def FResource(x):
1286      return F(x, use_resource=True)
1287
1288    @custom_gradient.custom_gradient
1289    def FNonResource(x):
1290      return F(x, use_resource=False)
1291
1292    x = array_ops.ones((3, 2)) + 2.
1293
1294    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
1295    with variable_scope.variable_scope("vs1", use_resource=True):
1296      with self.assertRaisesWithPredicateMatch(TypeError,
1297                                               "must be `ResourceVariable`s"):
1298        FNonResource(x)
1299
1300    # Wrapping scope has use_resource=False but inner scope sets to True.
1301    # Passes.
1302    with variable_scope.variable_scope("vs2", use_resource=False):
1303      FResource(x)
1304
1305  @parameterized.parameters(True, False)
1306  def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
1307    with context.eager_mode():
1308      x_captured = variables.Variable(3.)  # Used by FuncMult
1309      @custom_gradient.custom_gradient
1310      def FuncMult(x):
1311        def ActualGrad(dy, variables):  # pylint: disable=redefined-outer-name
1312          self.assertLen(variables, 1)
1313          self.assertIs(variables[0], x_captured)
1314          x_captured_grad = 5. * x * dy
1315          return (4. * x_captured * dy, [x_captured_grad])
1316        # Define the returned GradMult; "variables" must be a keyword-only argument.
1317        if anonymous_varargs:
1318          def GradMult(dy, *, variables=None):  # pylint: disable=redefined-outer-name
1319            return ActualGrad(dy, variables)
1320        else:
1321          def GradMult(*dys, variables=None):  # pylint: disable=redefined-outer-name
1322            return ActualGrad(dys[0], variables)
1323
1324        return x * x_captured, GradMult
1325
1326      x = variables.Variable(6.)
1327      with backprop.GradientTape(persistent=True) as g:
1328        y = FuncMult(x)
1329      self.assertAllEqual(g.gradient(y, x), 4. * 3.)
1330
1331  def testWithNumpyInputs(self):
1332    with context.eager_mode():
1333
1334      @custom_gradient.custom_gradient
1335      def F(x):
1336        out = x
1337
1338        def Grad(_):
1339          return (None, None)
1340
1341        return out, Grad
1342
1343      x = np.ones((3, 2), dtype=np.float32)
1344      # Smoke test to ensure numpy inputs are accepted
1345      F(x)
1346
1347  @test_util.run_v1_only("b/120545219")
1348  def testRVGradientsDynamicCond(self):
1349    with self.cached_session():
1350      alpha = resource_variable_ops.ResourceVariable(
1351          np.random.random((1,)),
1352          dtype="float32")
1353
1354      conditional = array_ops.placeholder_with_default(True, shape=())
1355      output = control_flow_ops.cond(
1356          conditional, lambda: alpha * 2, lambda: alpha * 3)
1357
1358      g, = gradients_impl.gradients(output, alpha)
1359      self.evaluate(variables.global_variables_initializer())
1360      self.assertAllEqual(g, [2.0])
1361      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
1362
1363  def testRecursiveCustomGradient(self):
1364    @custom_gradient.custom_gradient
1365    def F(x):
1366      out = core_layers.dense(x, 3, use_bias=False)
1367
1368      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1369        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1370        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1371        return grads[0], [array_ops.ones((4, 3))]
1372
1373      return out, Grad
1374
1375    @custom_gradient.custom_gradient
1376    def DoubleF(x):
1377      out = F(x)
1378
1379      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1380        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1381        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1382        return grads[0], [array_ops.ones((4, 3))]
1383
1384      return out, Grad
1385    with ops.Graph().as_default():
1386      x = array_ops.ones((2, 4))
1387      with variable_scope.variable_scope("f", use_resource=True) as vs:
1388        y = DoubleF(x)
1389        all_vars = vs.global_variables()
1390        assert len(all_vars) == 1
1391      grads = gradients.gradients(y, [x, all_vars[0]])
1392      for g in grads:
1393        self.assertIsNotNone(g)
1394
1395      self.evaluate(variables.global_variables_initializer())
1396      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
1397      self.assertEqual(12., dw)
1398
1399  @parameterized.named_parameters(
1400      [(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""),  # pylint: disable=g-complex-comprehension
1401        x_struct, y_struct)
1402       for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
1403       for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
1404      ])
1405  @test_util.run_in_graph_and_eager_modes
1406  def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
1407    """Tests that custom_gradient can handle structured inputs/outputs."""
1408    def Zeros(x):
1409      return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
1410    def GetStruct(x):
1411      return nest.map_structure(lambda _: None, x)
1412
1413    def MakeVjp(f, *x):
1414      with backprop.GradientTape(persistent=True) as tape:
1415        tape.watch(nest.flatten(x))
1416        y = f(*x)
1417      def Vjp(dy):
1418        return tape.gradient(y, x, output_gradients=dy)
1419      return y, Vjp
1420
1421    @custom_gradient.custom_gradient
1422    def F(*x):
1423      self.assertEqual(x_struct, GetStruct(x))
1424      def Vjp(*dy):
1425        self.assertEqual(len(nest.flatten(y_struct)),
1426                         len(nest.flatten(dy)))
1427        return nest.flatten(Zeros(x_struct))
1428      return Zeros(y_struct), Vjp
1429
1430    x, dy = Zeros([x_struct, y_struct])
1431    y, vjp = MakeVjp(F, *x)
1432    dx = vjp(dy)
1433    self.assertEqual(x_struct, GetStruct(dx))
1434    self.assertEqual(y_struct, GetStruct(y))
1435
1436
1437class TensorListGradientsTest(test_util.TensorFlowTestCase):
1438
1439  def testDefaultGradYs(self):
1440    with ops.Graph().as_default():
1441      tl = list_ops.empty_tensor_list(
1442          element_dtype=dtypes.float32,
1443          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1444      a = constant(1.0)
1445      tl = list_ops.tensor_list_push_back(tl, a)
1446
1447      grad_tl = list_ops.empty_tensor_list(
1448          element_dtype=dtypes.float32,
1449          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(grad_tl, constant(5.0))
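
      # grad_ys supplies the incoming gradient for the tensor list `tl`; the
      # gradient of tensor_list_push_back pops the last pushed element (5.0),
      # which is the gradient that flows back to `a`.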
      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]

      self.assertEqual(self.evaluate(grad), 5.)


class VariablesGradientTest(test_util.TensorFlowTestCase,
                            parameterized.TestCase):

  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
    """Returns gradients of `test_fn` with respect to `vars_to_grad`."""

    test_fn_re = custom_gradient.recompute_grad(test_fn)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_fn_re(inputs, vars_to_grad)
      out = test_fn(inputs, vars_to_grad)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  def _grad(self, f, argnums=0):
    """Return a function which computes the gradient of `f`."""

    def F(*params):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = f(*params)
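      # UnconnectedGradients.ZERO makes tape.gradient return zeros rather than
      # None for params that are not connected to the outputs.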
      return tape.gradient(
          outputs,
          params[argnums],
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO)

    return F

  def _test_gradients(self, f, inputs, order, delta=1e-3, rtol=1e-2, atol=1e-6):
    """Tests backward jacobians of `f`'s [0, `order`)-order gradients."""
    if order < 1:
      raise ValueError(
          "`order` should be a positive integer, got '{}'.".format(order))
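    # For order > 1, recurse on the gradient function so the Jacobian check
    # below runs for f and for each of its first `order` - 1 gradients.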
    if order > 1:
      self._test_gradients(
          f=self._grad(f),
          inputs=inputs,
          order=order - 1,
          delta=delta,
          rtol=rtol,
          atol=atol)
    sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
        f, inputs, delta=delta)
    self.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)

  def testRecomputeGradWrapped(self):

    def f(x):  # pylint: disable=invalid-name
      return 2 * x

    g = custom_gradient.recompute_grad(f)
    self.assertIs(g.__wrapped__, f)

  def testRecomputeGradZeroSizeInput(self):

    def F(x):
      return 2 * x

    x = array_ops.constant(())
    grads_re = self._grad(custom_gradient.recompute_grad(F))(x)
    grads = self._grad(F)(x)
    self.assertAllClose(grads_re, grads)

    f_graph = function.defun(F, input_signature=[tensor_spec.TensorSpec(None)])
    grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x)
    grads = self._grad(f_graph)(x)
    self.assertAllClose(grads_re, grads)

  def testRecomputeGradDifferentDtypesInputs(self):

    def F(x1, x2):
      return 2 * x1, 2 * x2

    x1 = array_ops.constant(1, dtype=dtypes.int32)
    x2 = array_ops.constant(1., dtype=dtypes.float32)
    grads_re = self._grad(custom_gradient.recompute_grad(F))(x1, x2)
    grads = self._grad(F)(x1, x2)
    self.assertAllClose(grads_re, grads)

    f_graph = function.defun(
        F,
        input_signature=[
            tensor_spec.TensorSpec(None, dtype=dtypes.int32),
            tensor_spec.TensorSpec(None, dtype=dtypes.float32),
        ])
    grads_re = self._grad(custom_gradient.recompute_grad(f_graph))(x1, x2)
    grads = self._grad(f_graph)(x1, x2)
    self.assertAllClose(grads_re, grads)

  @test_util.run_v2_only
  def testCustomGradientRecomputeGradHigherOrder(self):
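
    # order=3 checks numeric vs. symbolic Jacobians for F, its gradient, and
    # its second-order gradient.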
    @custom_gradient.recompute_grad
    def F(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    self._test_gradients(F, [constant_op.constant([1.])], order=3)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works for grads of function args."""

    def TestFn(inputs, input_vars):
      return inputs * input_vars

    def TestFnSeq(inputs, input_vars):
      return (inputs * input_vars, inputs * input_vars * 2.0)

    with variable_scope.variable_scope("test", use_resource=True):
      test_var = variable_scope.get_variable(
          name="test_var",
          shape=10,
          trainable=True,
      )
      self.evaluate(test_var.assign(np.ones([10])))
      test_input = constant(np.ones((10, 10), dtype=np.float32))

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_input)

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      # Regression test for wrapping sequence-outputting functions.
      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_input)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)
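
  # set((True, context.executing_eagerly())) collapses to {True} in eager
  # mode, so only the GradientTape path is exercised there; in graph mode
  # both use_tape=True and use_tape=False are generated.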
  @parameterized.parameters(set((True, context.executing_eagerly())))
  def testFnRecomputeWithScopeGradient(self, use_tape):
    """Checks that recompute_grad works with var scope and GradientTape."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)

      with test_util.AbstractGradientTape(
          use_tape=use_tape, persistent=True) as tape:
        out_re = test_fn_re(test_input_t)
        out = TestFn(test_input_t)

    self.evaluate(variables.global_variables_initializer())
    grads_re = tape.gradient(out_re, variables.trainable_variables())
    grads = tape.gradient(out, variables.trainable_variables())

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecomputeSameTensor(self):
    """Checks recompute_grad when the wrapped f is called as f(x, x).

    Regression test for b/147369366.
    """

    def TestFnMul(x, y):
      return x * y

    def TestFnSingleVar(x, y):
      # pylint: disable=unused-argument
      return x

    with variable_scope.variable_scope("test", use_resource=True):
      x = array_ops.ones((10))

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)
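

# grad_pass_through wraps a function so that it runs normally in the forward
# pass while incoming gradients are forwarded, unchanged, to its inputs as if
# the wrapped function were the identity.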
class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that the assign op is not differentiable.
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non-differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form the input as a quadratic function of the variable z and check that
    # the gradient of the output with respect to z is correct.
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)
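
    # d(z**2)/dz = 2 * z, which evaluates to 6.0 at z = 3.0.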
    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
      self.assertAllClose(grads[0], 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that the assign op is not differentiable.
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non-differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form the input as a quadratic function of the variable z and check that
    # the gradient of the output with respect to z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()