• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for tensorflow.ops.gradients."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20import sys
21import warnings
22
23from absl.testing import parameterized
24import numpy as np
25from tensorflow.python.client import session
26from tensorflow.python.eager import backprop
27from tensorflow.python.eager import context
28from tensorflow.python.eager import function
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import function as framework_function
32from tensorflow.python.framework import ops
33from tensorflow.python.framework import test_ops
34from tensorflow.python.framework import test_util
35from tensorflow.python.framework.constant_op import constant
36from tensorflow.python.layers import core as core_layers
37from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
38from tensorflow.python.ops import array_ops
39from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
40from tensorflow.python.ops import control_flow_ops
41from tensorflow.python.ops import custom_gradient
42from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
43from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
44from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
45from tensorflow.python.ops import gradient_checker_v2
46from tensorflow.python.ops import gradients
47from tensorflow.python.ops import gradients_impl
48from tensorflow.python.ops import init_ops
49from tensorflow.python.ops import list_ops
50from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
51from tensorflow.python.ops import math_ops
52from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
53from tensorflow.python.ops import resource_variable_ops
54from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
55from tensorflow.python.ops import state_ops
56from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
57from tensorflow.python.ops import tensor_array_ops
58from tensorflow.python.ops import unconnected_gradients
59from tensorflow.python.ops import variable_scope
60from tensorflow.python.ops import variables
61from tensorflow.python.ops.nn_ops import bias_add
62from tensorflow.python.platform import googletest
63from tensorflow.python.util import nest
64
65
66class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
67
68  def testGradients(self):
69    with ops.Graph().as_default():
70      inp = constant(1.0, shape=[32, 100], name="in")
71      w = constant(1.0, shape=[100, 10], name="w")
72      b = constant(1.0, shape=[10], name="b")
73      xw = math_ops.matmul(inp, w, name="xw")
74      h = bias_add(xw, b, name="h")
75      w_grad = gradients.gradients(h, w)[0]
76    self.assertEqual("MatMul", w_grad.op.type)
77    self.assertEqual(w_grad.op._original_op, xw.op)
78    self.assertTrue(w_grad.op.get_attr("transpose_a"))
79    self.assertFalse(w_grad.op.get_attr("transpose_b"))
80
81  def testUnusedOutput(self):
82    with ops.Graph().as_default():
83      w = constant(1.0, shape=[2, 2])
84      x = constant(1.0, shape=[2, 2])
85      wx = math_ops.matmul(w, x)
86      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
87      c = math_ops.reduce_sum(split_wx[1])
88      gw = gradients.gradients(c, [w])[0]
89    self.assertEqual("MatMul", gw.op.type)
90
91  def testColocateGradients(self):
92    with ops.Graph().as_default() as g:
93      w = constant(1.0, shape=[1, 1])
94      x = constant(1.0, shape=[1, 2])
95      with g.device("/device:GPU:0"):
96        wx = math_ops.matmul(w, x)
97      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
98    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())
99
100  def testColocateGradientsWithAggregation(self):
101    with ops.Graph().as_default() as g:
102      with g.device("/device:GPU:1"):
103        w = constant(1.0, shape=[1, 1])
104      x = constant(1.0, shape=[1, 2])
105      y = constant(1.0, shape=[1, 2])
106      wx = math_ops.matmul(w, x)
107      wy = math_ops.matmul(w, y)
108      with g.device("/device:GPU:0"):
109        z = wx + wy
110
111      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
112      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())
113
114      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
115      self.assertNotEqual(wx.op.colocation_groups(), gw2.op.colocation_groups())
116
117  def testColocateGradientsWithAggregationInMultipleDevices(self):
118    with ops.Graph().as_default() as g:
119      with g.device("/device:GPU:1"):
120        w = constant(1.0, shape=[1, 1])
121      x = constant(1.0, shape=[1, 2])
122      y = constant(1.0, shape=[1, 2])
123      with g.device("/task:1"):
124        wx = math_ops.matmul(w, x)
125      with g.device("/task:2"):
126        wy = math_ops.matmul(w, y)
127      with g.device("/device:GPU:0"):
128        z = wx + wy
129
130      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
131      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())
132
133      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
134      self.assertNotEqual(w.op.colocation_groups(), gw2.op.colocation_groups())
135
136  def testColocateGradientsWithGateGradients(self):
137    if not test_util.is_gpu_available():
138      self.skipTest("No GPU available")
139    with ops.Graph().as_default() as g:
140      with g.device("/device:CPU:0"):
141        x = constant(1.0, shape=[1, 1])
142        y = constant(1.0, shape=[1, 1])
143        s = x + y
144      with g.device("/device:GPU:0"):
145        z = math_ops.reduce_sum(s)
146
147      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
148                                 gate_gradients=True)[0]
149
150      # Make sure the placer doesn't complain.
151      self.evaluate(gz_x)
152
153  def testBoundaryStop(self):
154    # Test that we don't differentiate 'x'. The gradient function for 'x' is
155    # set explicitly to None so we will get an exception if the gradient code
156    # tries to differentiate 'x'.
157    with ops.Graph().as_default():
158      c = constant(1.0)
159      x = array_ops.identity(c)
160      y = x + 1.0
161      z = y + 1
162      grads = gradients.gradients(z, [x])
163      self.assertTrue(all(x is not None for x in grads))
164
165  @test_util.run_v1_only("b/120545219")
166  def testBoundaryContinue(self):
167    # Test that we differentiate both 'x' and 'y' correctly when x is a
168    # predecessor of y.
169    with self.cached_session():
170      x = constant(1.0)
171      y = x * 2.0
172      z = y * 3.0
173      grads = gradients.gradients(z, [x, y])
174      self.assertTrue(all(x is not None for x in grads))
175      self.assertEqual(6.0, grads[0].eval())
176
177  @test_util.run_v1_only("b/120545219")
178  def testAggregationMethodAccumulateN(self):
179    with self.cached_session():
180      x = constant(1.0)
181      y = x * 2.0
182      z = y + y + y + y + y + y + y + y + y + y
183      grads = gradients.gradients(
184          z, [x, y],
185          aggregation_method=gradients.AggregationMethod.
186          EXPERIMENTAL_ACCUMULATE_N)
187      self.assertTrue(all(x is not None for x in grads))
188      self.assertEqual(20.0, grads[0].eval())
189      self.assertEqual(10.0, grads[1].eval())
190
191  @test_util.run_v1_only("b/120545219")
192  def testAggregationMethodAddN(self):
193    with self.cached_session():
194      x = constant(1.0)
195      y = x * 2.0
196      z = y + y + y + y + y + y + y + y + y + y
197      grads = gradients.gradients(
198          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
199      self.assertTrue(all(x is not None for x in grads))
200      self.assertEqual(20.0, grads[0].eval())
201      self.assertEqual(10.0, grads[1].eval())
202
203  @test_util.run_v1_only("b/120545219")
204  def testAggregationMethodTree(self):
205    with self.cached_session():
206      x = constant(1.0)
207      y = x * 2.0
208      z = y + y + y + y + y + y + y + y + y + y
209      grads = gradients.gradients(
210          z, [x, y],
211          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
212      self.assertTrue(all(x is not None for x in grads))
213      self.assertEqual(20.0, grads[0].eval())
214      self.assertEqual(10.0, grads[1].eval())
215
216  def testNoGradientForStringOutputs(self):
217    with ops.Graph().as_default():
218
219      def _TestOpGrad(_, float_grad, string_grad):
220        """Gradient function for TestStringOutput."""
221        self.assertEqual(float_grad.dtype, dtypes.float32)
222        self.assertFalse(string_grad)
223        return float_grad
224
225      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)
226
227      c = constant(1.0)
228      x, _ = test_ops.test_string_output(c)
229      z = x * 2.0
230      w = z * 3.0
231      grads = gradients.gradients(z, [c])
232      self.assertIsInstance(grads[0], ops.Tensor)
233      grads = gradients.gradients(w, [c])
234      self.assertIsInstance(grads[0], ops.Tensor)
235
236  def testNoGradientForStringOutputsWithOpNamespace(self):
237    with ops.Graph().as_default():
238
239      def _TestOpGrad(_, float_grad, string_grad):
240        """Gradient function for TestStringOutput."""
241        self.assertEqual(float_grad.dtype, dtypes.float32)
242        self.assertFalse(string_grad)
243        return float_grad
244
245      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)
246
247      c = constant(1.0)
248      x, _ = test_ops.namespace_test_string_output(c)
249      z = x * 2.0
250      w = z * 3.0
251      grads = gradients.gradients(z, [c])
252      self.assertIsInstance(grads[0], ops.Tensor)
253      grads = gradients.gradients(w, [c])
254      self.assertIsInstance(grads[0], ops.Tensor)
255
256  def testSingletonIndexedSlices(self):
257    with ops.Graph().as_default():
258      x = array_ops.placeholder(dtypes.float32)
259      y = array_ops.identity(x)
260      dy = ops.IndexedSlices(
261          array_ops.placeholder(dtypes.float32),
262          array_ops.placeholder(dtypes.int32))
263      dx, = gradients.gradients(y, x, grad_ys=dy)
264      # The IndexedSlices gradient of tf.identity is the identity map.
265      with self.cached_session() as sess:
266        vdx, vdy = sess.run(
267            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
268      self.assertEqual(vdx, vdy)
269
270  @test_util.run_v1_only("b/120545219")
271  def testNonDifferentiableSwitchInWhileLoop(self):
272    with ops.Graph().as_default():
273      v = array_ops.placeholder(dtypes.float32, [])
274
275      def _Step(i, a, ta):
276        a += math_ops.cast(v, dtypes.int32)
277        return (i + 1, a, ta.write(i, a))
278
279      n = 4
280      i, _, ta = control_flow_ops.while_loop(
281          lambda i, *_: i < n,
282          _Step, [0, 0, tensor_array_ops.TensorArray(
283              dtypes.int32, size=n)])
284      target = ta.read(i - 1)
285      grad, = gradients.gradients(target, v)
286      self.assertIsNone(grad)
287
288  def testVariableReadValueGradient(self):
289    with ops.Graph().as_default():
290      init = constant_op.constant(100.0)
291      var = variables.Variable(init)
292      gradient = gradients.gradients(var.read_value(), var)
293      self.assertIsNotNone(gradient)
294
295  @parameterized.parameters(dtypes.float32, dtypes.float64)
296  def testVariableDefaultGrad(self, dtype):
297    with ops.Graph().as_default():
298      init = constant_op.constant(100.0, dtype=dtype)
299      var = variables.Variable(init)
300      dummy_const = constant_op.constant(0.0)
301      gradient = gradients.gradients(
302          dummy_const,
303          var,
304          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
305      )[0]
306      self.assertEqual(gradient.dtype, dtype)
307      self.assertIsNotNone(gradient)
308
309  def testVariableAsGraphElementGradient(self):
310    with ops.Graph().as_default() as graph:
311      init = constant_op.constant(100.0)
312      var = variables.Variable(init)
313      gradient = gradients.gradients(graph.as_graph_element(var), var)
314      self.assertIsNotNone(gradient)
315
316  @test_util.run_v1_only("b/120545219")
317  def testVariableRefGradient(self):
318    with ops.Graph().as_default():
319      init = constant_op.constant(100.0)
320      var = variables.VariableV1(init)
321      gradient = gradients.gradients(var._ref(), var)
322      self.assertIsNotNone(gradient)
323
324  @test_util.run_v1_only("b/120545219")
325  def testDependentYs(self):
326    with self.cached_session():
327      x = constant_op.constant(3.0)
328      y = math_ops.square(x)
329      y1 = math_ops.square(y)
330      y2 = math_ops.square(y1)
331      g = gradients.gradients([y, y2], x)
332      self.assertAllClose(17502.0, g[0])
333      g = gradients.gradients(y + y2, x)
334      self.assertAllClose(17502.0, g[0])
335      z = array_ops.identity(y)
336      z2 = array_ops.identity(y2)
337      g = gradients.gradients([z, z2], x)
338      self.assertAllClose(17502.0, g[0])
339
340  @test_util.run_v1_only("b/120545219")
341  def testPartialDerivatives(self):
342    with self.cached_session():
343      x = constant_op.constant(1.)
344      y = 2 * x
345      z = x + y
346      totalg = gradients.gradients(z, [x, y])
347      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
348      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
349      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])
350
351  @test_util.run_v1_only("b/120545219")
352  def testStopGradients(self):
353    def _MakeGraph(rng, stop_gradients=()):
354      def _FunctionOf(xs, k=3):
355        return ops.convert_to_tensor(
356            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
357            + rng.rand(k, k))
358
359      a = _FunctionOf([])
360      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
361      b = _FunctionOf([a])
362      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
363      c = _FunctionOf([a, b])
364      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
365      d = _FunctionOf([b, c])
366      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
367      return dict(a=a, b=b, c=c, d=d)
368
369    def _Gradients(ys, xs, **kwargs):
370      dydxs = gradients.gradients(ys, xs, **kwargs)
371      dydxs = [0. * x if dydx is None else dydx
372               for x, dydx in zip(xs, dydxs)]
373      return dydxs
374
375    seed = np.random.randint(1000)
376    cases = []
377    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
378    graph = _MakeGraph(np.random.RandomState(seed))
379    for constants in subsets:
380      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
381      for variables_ in subsets:
382        # compute the gradient when stopped using tf.stop_gradients
383        grad1 = _Gradients([graph_with_stops["d"]],
384                           [graph_with_stops[v] for v in variables_])
385        # compute the gradient when stopped using the stop_gradients kwarg
386        grad2 = _Gradients([graph["d"]],
387                           [graph[v] for v in variables_],
388                           stop_gradients=[graph[v] for v in constants])
389        cases.append(dict(grad1=grad1, grad2=grad2,
390                          constants=constants, variables=variables_))
391
392    # evaluate all tensors in one call to session.run for speed
393    with self.cached_session() as sess:
394      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])
395
396    for (npgrad1, npgrad2), case in zip(results, cases):
397      for a, b in zip(npgrad1, npgrad2):
398        np.testing.assert_allclose(a, b)
399
400  def testUnconnectedGradientsNoneUnconnectedGradients(self):
401    with ops.Graph().as_default():
402      x = constant(1.0, shape=[2, 2])
403      y = constant(3.0, shape=[3, 1])
404      grad = gradients.gradients(
405          [y], [x], unconnected_gradients="none")
406    self.assertIsNone(grad[0])
407
408  def testUnconnectedGradientsZerosUnconnectedGradients(self):
409    with ops.Graph().as_default():
410      x = constant(1.0, shape=[2, 2])
411      y = constant(3.0, shape=[3, 1])
412      grads = gradients.gradients(
413          [y], [x], unconnected_gradients="zero")
414
415      self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])
416
417  def testUnconnectedGradientsZeroConnectedGradients(self):
418    with ops.Graph().as_default():
419      x = constant(1.0)
420      y = x * 3.0
421      grad = gradients.gradients(
422          [y], [x], unconnected_gradients="zero")
423
424      self.assertEqual(3.0, self.evaluate(grad)[0])
425
426  def testUnknownUnconnectedGradientsValueGiven(self):
427    with ops.Graph().as_default():
428      x = constant(1.0)
429      y = constant(1.0)
430      with self.assertRaisesRegex(
431          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
432        gradients.gradients([y], [x], unconnected_gradients="nonsense")
433
434  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
435                            unconnected_gradients.UnconnectedGradients.NONE)
436  def testUnconnectedOpWithMultipleOutputs(self, unconnected_gradients_val):
437    with ops.Graph().as_default():
438      #  a    b
439      #  |    |
440      # IdentityN
441      #  |    |
442      #  c    d
443      #  |
444      # Identity
445      #  |
446      #  e
447      a = constant_op.constant(1.0)
448      b = constant_op.constant(1.0)
449      c, d = array_ops.identity_n([a, b])
450      e = array_ops.identity(c)
451      # The aggregated grads for the IdentityN node would look like
452      # [Tensor, None]. We expect this None to be converted to zeros.
453      output = gradients.gradients(
454          e, d, unconnected_gradients=unconnected_gradients_val)
455      if (unconnected_gradients_val ==
456          unconnected_gradients.UnconnectedGradients.ZERO):
457        self.assertIsNotNone(output[0])
458      else:
459        self.assertIsNone(output[0])
460
461  @parameterized.parameters(unconnected_gradients.UnconnectedGradients.ZERO,
462                            unconnected_gradients.UnconnectedGradients.NONE)
463  def testUnconnectedOpWithMultipleOutputsStopGradient(
464      self, unconnected_gradients_val):
465    with ops.Graph().as_default():
466      #  a    b
467      #  |    |
468      # IdentityN
469      #  |    |
470      #  c    d
471      #  |    |
472      #  SG   |
473      #  |    |
474      #   \  /
475      #    +
476      #    e
477      a = constant_op.constant(1.0)
478      b = constant_op.constant(1.0)
479      c, d = array_ops.identity_n([a, b])
480      e = array_ops.stop_gradient(c) + d
481      # The aggregated grads for the IdentityN node would look like
482      # [None, Tensor]. We expect this None to be converted to zeros.
483      output = gradients.gradients(
484          e, c, unconnected_gradients=unconnected_gradients_val)
485      if (unconnected_gradients_val ==
486          unconnected_gradients.UnconnectedGradients.ZERO):
487        self.assertIsNotNone(output[0])
488      else:
489        self.assertIsNone(output[0])
490
491
492class FunctionGradientsTest(test_util.TensorFlowTestCase):
493
494  @classmethod
495  def XSquarePlusB(cls, x, b):
496    return x * x + b
497
498  @classmethod
499  def XSquarePlusBGradient(cls, x, b, g):
500    # Perturb gradients (multiply by 2), so we can test that this was called.
501    g *= 2.0
502    return g * 2.0 * x, g
503
504  @classmethod
505  def _PythonGradient(cls, op, grad):
506    # Perturb gradients (multiply by 3), so we can test that this was called.
507    grad *= 3.0
508    return grad * op.inputs[0] * 2.0, grad
509
510  @classmethod
511  def _GetFunc(cls, **kwargs):
512    return framework_function.Defun(dtypes.float32, dtypes.float32, **
513                                    kwargs)(cls.XSquarePlusB)
514
515  def _GetFuncGradients(self, f, x_value, b_value):
516    x = constant_op.constant(x_value, name="x")
517    b = constant_op.constant(b_value, name="b")
518
519    y = f(x, b)
520    grads = gradients.gradients(y, [x, b])
521
522    return self.evaluate(grads)
523
524  def testFunctionGradientsBasic(self):
525    g = ops.Graph()
526    with g.as_default():
527      f = self._GetFunc()
528      # Get gradients (should add SymbolicGradient node for function).
529      grads = self._GetFuncGradients(f, [2.0], [1.0])
530      self.assertAllEqual([4.0], grads[0])
531      self.assertAllEqual([1.0], grads[1])
532
533  def testFunctionGradientsComposition(self):
534    with ops.Graph().as_default():
535      f = self._GetFunc()
536      x = constant_op.constant([2.0], name="x")
537      b1 = constant_op.constant([1.0], name="b1")
538      b2 = constant_op.constant([1.0], name="b2")
539
540      y = f(f(x, b1), b2)
541      # Build gradient graph (should add SymbolicGradient node for function).
542      grads = gradients.gradients(y, [x, b1])
543
544      self.assertAllEqual([40.0], self.evaluate(grads)[0])
545      self.assertAllEqual([10.0], self.evaluate(grads)[1])
546
547  def testFunctionGradientsWithGradFunc(self):
548    g = ops.Graph()
549    with g.as_default():
550      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
551                                           dtypes.float32)(
552                                               self.XSquarePlusBGradient)
553      f = self._GetFunc(grad_func=grad_func)
554      # Get gradients (should add SymbolicGradient node for function, which
555      # uses the grad_func above, which multiplies all gradients by 2).
556      grads = self._GetFuncGradients(f, [2.0], [1.0])
557      self.assertAllEqual([4.0 * 2], grads[0])
558      self.assertAllEqual([1.0 * 2], grads[1])
559
560  def testFunctionGradientWithRegistration(self):
561    g = ops.Graph()
562    with g.as_default():
563      f = self._GetFunc(python_grad_func=self._PythonGradient)
564      # Get gradients, using the python gradient function. It multiplies the
565      # gradients by 3.
566      grads = self._GetFuncGradients(f, [2.0], [1.0])
567      self.assertAllEqual([4.0 * 3], grads[0])
568      self.assertAllEqual([1.0 * 3], grads[1])
569
570  def testFunctionGradientWithGradFuncAndRegistration(self):
571    g = ops.Graph()
572    with g.as_default():
573      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
574                                           dtypes.float32)(
575                                               self.XSquarePlusBGradient)
576      with self.assertRaisesRegex(ValueError, "Gradient defined twice"):
577        f = self._GetFunc(
578            grad_func=grad_func, python_grad_func=self._PythonGradient)
579        f.add_to_graph(ops.Graph())
580
581  def testGradientWrtCaptured(self):
582    with ops.Graph().as_default():
583      x = constant_op.constant(1.0, name="x")
584
585      @function.defun()
586      def Foo():
587        y = math_ops.multiply(x, 2.0, name="y")
588        g = gradients_impl.gradients(y, x)
589        return g[0]
590
591      f = Foo()
592
593      self.assertEqual(self.evaluate(f), 2.0)
594
595  def testGradientOfCaptured(self):
596    with ops.Graph().as_default():
597      x = constant_op.constant(1.0, name="x")
598      y = math_ops.multiply(x, 2.0, name="y")
599
600      @framework_function.Defun()
601      def Foo():
602        g = gradients_impl.gradients(y, x)
603        return g[0]
604
605      f = Foo()
606
607      self.assertEqual(self.evaluate(f), 2.0)
608
609  def testCapturedResourceVariable(self):
610    with ops.Graph().as_default():
611      var = resource_variable_ops.ResourceVariable(1.0, name="var")
612
613      @function.defun()
614      def Foo():
615        y = math_ops.multiply(var, 2.0, name="y")
616        g = gradients_impl.gradients(y, var)
617        return g[0]
618
619      f = Foo()
620
621      self.evaluate(variables.global_variables_initializer())
622      self.assertEqual(self.evaluate(f), 2.0)
623
624  def testCapturedNested(self):
625    with ops.Graph().as_default():
626      x1 = constant_op.constant(1.0, name="x1")
627      x2 = constant_op.constant(2.0, name="x2")
628      x3 = math_ops.multiply(x1, x2, name="x3")
629
630      @function.defun()
631      def Outer():
632        outer1 = array_ops.identity(x1, name="outer1")
633
634        @function.defun()
635        def Inner():
636          inner1 = array_ops.identity(outer1, name="inner1")
637          inner2 = array_ops.identity(x2, name="inner2")
638          inner3 = array_ops.identity(x3, name="inner3")
639          return gradients_impl.gradients([inner1, inner2, inner3, x1],
640                                          [x1, x2])
641
642        return Inner()
643
644      x1_grad, x2_grad = Outer()
645
646      # 1.0 + None + 2.0 + 1.0 = 4.0
647      self.assertEqual(self.evaluate(x1_grad), 4.0)
648      # None + 1.0 + 1.0 + None = 2.0
649      self.assertEqual(self.evaluate(x2_grad), 2.0)
650
651  def testCapturedFromFunction(self):
652    with ops.Graph().as_default():
653      x = constant_op.constant(1.0, name="x")
654
655      @function.defun()
656      def Outer():
657        y = math_ops.multiply(x, 2.0, name="y")
658
659        @function.defun()
660        def Inner():
661          z = math_ops.multiply(y, 3.0, name="z")
662          g = gradients_impl.gradients(z, y)
663          return g[0]
664
665        return Inner()
666
667      z_grad = Outer()
668
669      self.assertEqual(self.evaluate(z_grad), 3.0)
670
671  def testCapturedEagerTensors(self):
672    # Test that we can handle captured eager tensors unrelated to the gradient
673    # computation (i.e. we need to ignore them).
674    # TODO(skyewm): make it an error if you try to take the gradient wrt a
675    # captured EagerTensor
676    with context.eager_mode():
677      c = constant_op.constant(2.0, name="c")
678
679      @function.defun
680      def Foo():
681        x = constant_op.constant(10.0, name="x")
682        y = math_ops.multiply(x, c, name="y")
683        # Regression test for b/122564611.
684        z = math_ops.multiply(c, y, name="z")
685        g = gradients_impl.gradients(z, x)
686        return g[0]
687
688      self.assertEqual(Foo().numpy(), 4.0)
689
690
691class StopGradientTest(test_util.TensorFlowTestCase):
692
693  def testStopGradient(self):
694    with ops.Graph().as_default():
695      inp = constant(1.0, shape=[100, 32], name="in")
696      out = array_ops.stop_gradient(inp)
697      igrad = gradients.gradients(out, inp)[0]
698    assert igrad is None
699
700
701class PreventGradientTest(test_util.TensorFlowTestCase):
702
703  def testPreventGradient(self):
704    with ops.Graph().as_default():
705      inp = constant(1.0, shape=[100, 32], name="in")
706      out = array_ops.prevent_gradient(inp)
707      with self.assertRaisesRegex(LookupError, "explicitly disabled"):
708        _ = gradients.gradients(out, inp)
709
710
711class HessianVectorProductTest(test_util.TensorFlowTestCase):
712
713  @test_util.run_v1_only("b/120545219")
714  def testHessianVectorProduct(self):
715    # Manually compute the Hessian explicitly for a low-dimensional problem
716    # and check that HessianVectorProduct matches multiplication by the
717    # explicit Hessian.
718    # Specifically, the Hessian of f(x) = x^T A x is
719    # H = A + A^T.
720    # We expect HessianVectorProduct(f(x), x, v) to be H v.
721    m = 4
722    rng = np.random.RandomState([1, 2, 3])
723    mat_value = rng.randn(m, m).astype("float32")
724    v_value = rng.randn(m, 1).astype("float32")
725    x_value = rng.randn(m, 1).astype("float32")
726    hess_value = mat_value + mat_value.T
727    hess_v_value = np.dot(hess_value, v_value)
728    for use_gpu in [False, True]:
729      with self.cached_session(use_gpu=use_gpu):
730        mat = constant_op.constant(mat_value)
731        v = constant_op.constant(v_value)
732        x = constant_op.constant(x_value)
733        mat_x = math_ops.matmul(mat, x, name="Ax")
734        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
735        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
736        hess_v_actual = self.evaluate(hess_v)
737      self.assertAllClose(hess_v_value, hess_v_actual)
738
739
740class HessianTest(test_util.TensorFlowTestCase):
741
742  @test_util.run_v1_only("b/120545219")
743  def testHessian1D(self):
744    # Manually compute the Hessian explicitly for a low-dimensional problem
745    # and check that `hessian` matches. Specifically, the Hessian of
746    # f(x) = x^T A x is H = A + A^T.
747    m = 4
748    rng = np.random.RandomState([1, 2, 3])
749    mat_value = rng.randn(m, m).astype("float32")
750    x_value = rng.randn(m).astype("float32")
751    hess_value = mat_value + mat_value.T
752    with self.session():
753      mat = constant_op.constant(mat_value)
754      x = constant_op.constant(x_value)
755      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
756      hess = gradients.hessians(x_mat_x, x)[0]
757      hess_actual = self.evaluate(hess)
758    self.assertAllClose(hess_value, hess_actual)
759
760  @test_util.run_v1_only("b/120545219")
761  def testHessian1D_multi(self):
762    # Test the computation of the hessian with respect to multiple tensors
763    m = 4
764    n = 3
765    rng = np.random.RandomState([1, 2, 3])
766    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
767    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
768    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
769    with self.session():
770      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
771      xs = [constant_op.constant(x_value) for x_value in x_values]
772      xs_mats_xs = [
773          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
774          for x, mat in zip(xs, mats)
775      ]
776      hessians = gradients.hessians(xs_mats_xs, xs)
777      hessians_actual = [hess.eval() for hess in hessians]
778    for hess_value, hess_actual in zip(hess_values, hessians_actual):
779      self.assertAllClose(hess_value, hess_actual)
780
781  @test_util.run_v1_only("b/120545219")
782  def testHessianInvalidDimension(self):
783    for shape in [(10, 10), None]:
784      with self.cached_session():
785        x = array_ops.placeholder(dtypes.float32, shape)
786        # Expect a ValueError because the dimensions are wrong
787        with self.assertRaises(ValueError):
788          gradients.hessians(x, x)
789
790  @test_util.run_v1_only("b/120545219")
791  def testHessian2D_square_matrix(self):
792    # Manually compute the Hessian explicitly for a low-dimensional problem
793    # and check that `hessian` matches. Specifically, the Hessian of
794    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
795    m = 3
796    rng = np.random.RandomState([1, 2, 3])
797    x_value = rng.randn(m, m).astype("float32")
798    with self.session():
799      x = constant_op.constant(x_value)
800      x_square = math_ops.reduce_sum(
801          math_ops.matmul(array_ops.transpose(x), x) * 0.5
802      )
803      hess = gradients.hessians(x_square, x)[0]
804      hess_actual = self.evaluate(hess)
805    hess_value = np.bmat([
806        [elem*np.ones((m, m)) for elem in vec]
807        for vec in np.eye(m)
808    ]).astype("float32")
809    self.assertAllEqual((m, m, m, m), hess_actual.shape)
810    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))
811
812  @test_util.run_v1_only("b/120545219")
813  def testHessian2D_non_square_matrix(self):
814    m = 3
815    n = 4
816    rng = np.random.RandomState([1, 2, 3])
817    x_value = rng.randn(m, n).astype("float32")
818    with self.session():
819      x = constant_op.constant(x_value)
820      x_square = math_ops.reduce_sum(
821          math_ops.matmul(array_ops.transpose(x), x) * 0.5
822      )
823      hess = gradients.hessians(x_square, x)[0]
824      hess_actual = self.evaluate(hess)
825    hess_value = np.bmat([
826        [elem*np.ones((n, n)) for elem in vec]
827        for vec in np.eye(m)
828    ]).astype("float32")
829    self.assertAllEqual((m, n, m, n), hess_actual.shape)
830    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))
831
832
833class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):
834
835  @test_util.run_v1_only("b/120545219")
836  def testIndexedSlicesToTensor(self):
837    with self.cached_session():
838      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
839      c = constant_op.constant(np_val)
840      c_sparse = math_ops._as_indexed_slices(c)
841      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
842      c_dense = math_ops.multiply(c_sparse, 1.0)
843      self.assertAllClose(np_val, self.evaluate(c_dense))
844
845  @test_util.run_v1_only("b/120545219")
846  def testIndexedSlicesToTensorList(self):
847    with self.cached_session():
848      numpy_list = []
849      dense_list = []
850      sparse_list = []
851      for _ in range(3):
852        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
853        c = constant_op.constant(np_val)
854        c_sparse = math_ops._as_indexed_slices(c)
855        numpy_list.append(np_val)
856        dense_list.append(c)
857        sparse_list.append(c_sparse)
858      packed_dense = array_ops.stack(dense_list)
859      packed_sparse = array_ops.stack(sparse_list)
860      self.assertAllClose(packed_dense, self.evaluate(packed_sparse))
861
862  @test_util.run_v1_only("b/120545219")
863  def testInt64Indices(self):
864    with self.cached_session():
865      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
866      c = constant_op.constant(np_val)
867      c_sparse = math_ops._as_indexed_slices(c)
868      c_sparse = ops.IndexedSlices(
869          c_sparse.values,
870          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
871      self.assertAllEqual(np_val.shape, c_sparse.dense_shape)
872      c_dense = math_ops.multiply(c_sparse, 1.0)
873      self.assertAllClose(np_val, self.evaluate(c_dense))
874
875  @test_util.run_v1_only("b/120545219")
876  def testWarnings(self):
877    # TODO(gunan) Reenable after this issue is fixed:
878    # https://github.com/google/protobuf/issues/2812
879    if sys.version_info >= (3, 5):
880      self.skipTest("Skipped test for Python 3.5+")
881
882    # Smaller than the threshold: no warning.
883    c_sparse = ops.IndexedSlices(
884        array_ops.placeholder(dtypes.float32),
885        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
886    with warnings.catch_warnings(record=True) as w:
887      math_ops.multiply(c_sparse, 1.0)
888    self.assertEqual(0, len(w))
889
890    # Greater than or equal to the threshold: warning.
891    c_sparse = ops.IndexedSlices(
892        array_ops.placeholder(dtypes.float32),
893        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
894    # "always" filter prevents the warning from being suppressed if it was
895    # already triggered in a different test.
896    warnings.simplefilter("always")
897    with warnings.catch_warnings(record=True) as w:
898      math_ops.multiply(c_sparse, 1.0)
899    self.assertEqual(1, len(w))
900    self.assertIn(
901        "with 100000000 elements. This may consume a large amount of memory.",
902        str(w[0].message))
903
904    # Unknown dense shape: warning.
905    c_sparse = ops.IndexedSlices(
906        array_ops.placeholder(dtypes.float32),
907        array_ops.placeholder(dtypes.int32),
908        array_ops.placeholder(dtypes.int32))
909    with warnings.catch_warnings(record=True) as w:
910      math_ops.multiply(c_sparse, 1.0)
911    self.assertEqual(1, len(w))
912    self.assertIn(
913        "of unknown shape. This may consume a large amount of memory.",
914        str(w[0].message))
915
916
917class OnlyRealGradientsTest(test_util.TensorFlowTestCase):
918
919  @test_util.run_v1_only("b/120545219")
920  def testRealOnly(self):
921    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
922    y = math_ops.square(x)
923    with self.assertRaisesRegex(
924        TypeError, r"Gradients of complex tensors must set grad_ys "
925        r"\(y\.dtype = tf\.complex64\)"):
926      gradients.gradients(y, x)
927
928
929class ResourceCondTest(test_util.TensorFlowTestCase):
930
931  @test_util.run_v1_only("b/120545219")
932  def testBasic(self):
933    gamma = resource_variable_ops.ResourceVariable(
934        np.random.random((3,)),
935        dtype="float32", name="gamma")
936
937    inputs = array_ops.ones(shape=(3,), dtype="float32")
938
939    def TestFn():
940      output = inputs + gamma
941      return output
942
943    training = array_ops.placeholder_with_default(True, shape=())
944    output = control_flow_ops.cond(
945        training, TestFn, lambda: inputs)
946
947    loss = output
948
949    grads = gradients.gradients(
950        loss, [gamma])
951    self.assertNotIn(None, grads)
952
953
954class GetDependentVariablesTest(test_util.TensorFlowTestCase):
955
956  def testNoVariables(self):
957    with ops.Graph().as_default():
958      func = lambda x: array_ops.identity(x) + 5.0
959      input_t = constant_op.constant(2.0)
960      result_t = func(input_t)
961      dependent_vars = custom_gradient._get_dependent_variables(
962          [input_t], [result_t])
963
964      # There are no variables.
965      self.assertEqual(dependent_vars, [])
966
967  def testVariablesOutside(self):
968    with ops.Graph().as_default():
969      init = constant_op.constant(100.0)
970      var = variables.Variable(init)
971
972      # The variable is closed over. It should be found.
973      func = lambda x: array_ops.identity(x) + 5.0 + var
974
975      input_t = constant_op.constant(2.0)
976      result_t = func(input_t)
977      dependent_vars = custom_gradient._get_dependent_variables(
978          [input_t], [result_t])
979      self.assertEqual(dependent_vars, [var])
980
981  def testVariableSamePrefix(self):
982    with ops.Graph().as_default():
983      var_name = "my_variable"
984      v_z = variable_scope.get_variable(var_name, shape=())
985      v_o = variable_scope.get_variable(var_name + "_ones", shape=())
986
987      # The variable is closed over. It should be found.
988      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o
989
990      input_t = constant_op.constant(2.0)
991      result_t = func(input_t)
992      dependent_vars = custom_gradient._get_dependent_variables(
993          [input_t], [result_t])
994      self.assertEqual(set(dependent_vars), set([v_o, v_z]))
995
996  def testVariablesOutsideButDSeparated(self):
997    with ops.Graph().as_default():
998      init = constant_op.constant(100.0)
999      var = variables.Variable(init)
1000
1001      # The variable is d-separated by the inputs. It should not be found.
1002      input_t = array_ops.identity(var) * 5.0
1003
1004      func = lambda x: array_ops.identity(x) + 5.0
1005      result_t = func(input_t)
1006      dependent_vars = custom_gradient._get_dependent_variables(
1007          [input_t], [result_t])
1008      self.assertEqual(dependent_vars, [])
1009
1010  def testVariablesOutsideAndNonDifferentiable(self):
1011    with ops.Graph().as_default():
1012      init = constant_op.constant(100.0, shape=(5,))
1013      var = variables.Variable(init, shape=(5,))
1014
1015      def _Func(x):
1016        # non-differentiable dependency on var.
1017        # the variable should not be found.
1018        y = array_ops.ones_like(var)
1019        return array_ops.identity(x) + 5.0 + y
1020
1021      input_t = constant_op.constant(2.0)
1022      result_t = _Func(input_t)
1023      dependent_vars = custom_gradient._get_dependent_variables(
1024          [input_t], [result_t])
1025      self.assertEqual(dependent_vars, [])
1026
1027  def testVariablesOutsideAndNonTrainable(self):
1028    with ops.Graph().as_default():
1029      init = constant_op.constant(100.0, shape=(5,))
1030
1031      # Both variables are used in the function but only the trainable one
1032      # should be found.
1033      var_trainable = variables.Variable(init, shape=(5,))
1034      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)
1035
1036      def _Func(x):
1037        del x
1038        return var_trainable + var_nontrainable
1039
1040      input_t = constant_op.constant(2.0)
1041      result_t = _Func(input_t)
1042      dependent_vars = custom_gradient._get_dependent_variables(
1043          [input_t], [result_t])
1044      self.assertEqual(dependent_vars, [var_trainable])
1045
1046  def testVariablesOutsideAndCustomGradient(self):
1047    with ops.Graph().as_default():
1048      init = constant_op.constant(100.0, shape=(5,))
1049      var = variables.Variable(init, shape=(5,))
1050
1051      @custom_gradient.custom_gradient
1052      def _MyOnesLike(x):
1053        """Dummy version of ones_like which defines a gradient."""
1054
1055        output = array_ops.ones_like(x)
1056
1057        def _Grad(dy):
1058          return array_ops.identity(dy)
1059
1060        return output, _Grad
1061
1062      def _Func(x):
1063        # non-differentiable operation with custom gradient.
1064        # The variable should be found.
1065        y = _MyOnesLike(var)
1066        return array_ops.identity(x) + 5.0 + y
1067
1068      input_t = constant_op.constant(2.0)
1069      result_t = _Func(input_t)
1070      dependent_vars = custom_gradient._get_dependent_variables(
1071          [input_t], [result_t])
1072      self.assertEqual(dependent_vars, [var])
1073
1074
1075class CustomGradientTest(test_util.TensorFlowTestCase, parameterized.TestCase):
1076
1077  def testCustomGradientTrivial(self):
1078
1079    @custom_gradient.custom_gradient
1080    def MyIdentity(x):
1081
1082      def Grad(dy):
1083        return [3 * dy]
1084
1085      return x, Grad
1086
1087    with ops.Graph().as_default():
1088      x = constant(3.)
1089      y = MyIdentity(MyIdentity(x))
1090      dy = gradients.gradients(y, x)[0]
1091      with session.Session():
1092        self.assertEqual(9., self.evaluate(dy))
1093
1094  def testCustomGradient(self):
1095
1096    @custom_gradient.custom_gradient
1097    def MyMultiply(x1, x2):
1098      result = x1 * x2
1099
1100      def Grad(dy):
1101        # Switched the ordering here.
1102        return [dy * x1, dy * x2]
1103
1104      return result, Grad
1105
1106    with ops.Graph().as_default():
1107      x1 = constant(3.)
1108      x2 = constant(5.)
1109      y = MyMultiply(x1, x2)
1110      dy = gradients.gradients(y, [x1, x2])
1111
1112      self.assertAllEqual([3., 5.], self.evaluate(dy))
1113
1114  def testCustomGradientClass(self):
1115
1116    class Model(object):
1117
1118      @custom_gradient.custom_gradient
1119      def Multiply(self, x1, x2):
1120        result = x1 * x2
1121        grad = lambda dy: (dy * x1, dy * x2)
1122        return result, grad
1123
1124    with ops.Graph().as_default():
1125      x1 = constant(3.)
1126      x2 = constant(5.)
1127      m = Model()
1128      y = m.Multiply(x1, x2)
1129      dy = gradients.gradients(y, [x1, x2])
1130      self.assertAllEqual([3., 5.], self.evaluate(dy))
1131
1132  def testCustomGradientErrors(self):
1133
1134    @custom_gradient.custom_gradient
1135    def F(x):
1136
1137      def Grad(_):
1138        raise RuntimeError("x")
1139
1140      return x, Grad
1141
1142    with ops.Graph().as_default():
1143      x = constant(1.0)
1144      y = F(x)
1145      with self.assertRaises(RuntimeError):
1146        gradients.gradients(y, x)
1147
1148  def testCustomGradientWithVariables(self):
1149
1150    @custom_gradient.custom_gradient
1151    def F(x):
1152      out = core_layers.dense(x, 3, use_bias=False)
1153
1154      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1155        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1156        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1157        return grads[0], [array_ops.ones((4, 3))]
1158
1159      return out, Grad
1160
1161    with ops.Graph().as_default():
1162      x = array_ops.ones((2, 4))
1163      with variable_scope.variable_scope("f", use_resource=True) as vs:
1164        y = F(x)
1165        all_vars = vs.global_variables()
1166        assert len(all_vars) == 1
1167      grads = gradients.gradients(y, [x, all_vars[0]])
1168      for g in grads:
1169        self.assertIsNotNone(g)
1170
1171      self.evaluate(variables.global_variables_initializer())
1172      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
1173      self.assertEqual(12., dw)
1174
1175  def testCustomGradientWithCapture(self):
1176    with ops.Graph().as_default():
1177      x = constant(3.)
1178
1179      @framework_function.Defun(dtypes.float32)
1180      def F(y):
1181
1182        @custom_gradient.custom_gradient
1183        def MyMultiply(x1, x2):
1184          result = x1 * x2
1185
1186          def Grad(dy):
1187            # Switched the ordering here.
1188            return [dy * x1, dy * x2]
1189
1190          return result, Grad
1191
1192        res = MyMultiply(x, y)
1193        return gradients.gradients(res, [y])
1194
1195      y = constant(5.)
1196      dy = F(y)
1197      self.assertAllEqual(5., self.evaluate(dy))
1198
1199  def testCustomGradientWithVariablesNoFalsePositives(self):
1200
1201    @custom_gradient.custom_gradient
1202    def F(x):
1203      out = core_layers.dense(x, 3, use_bias=False)
1204
1205      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1206        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1207        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1208        return grads[0], [array_ops.ones((3, 3))]
1209
1210      return out, Grad
1211
1212    with ops.Graph().as_default():
1213      with variable_scope.variable_scope("f", use_resource=True) as vs:
1214        a = array_ops.ones((2, 4))
1215
1216        # Variabes in these layers shouldn't be picked up by the decorator.
1217        b = core_layers.dense(a, 3, use_bias=False)
1218        c = core_layers.dense(b, 3, use_bias=False)
1219        x = core_layers.dense(b, 3, use_bias=False) + c
1220
1221        # Only the variables used in F.
1222        y = F(x)
1223
1224        all_vars = vs.global_variables()
1225        assert len(all_vars) == 4
1226      grads = gradients.gradients(y, [x] + all_vars)
1227      _, var_grads = grads[0], grads[1:]
1228      for g in grads:
1229        self.assertIsNotNone(g)
1230
1231      self.evaluate(variables.global_variables_initializer())
1232      dw = self.evaluate(math_ops.reduce_sum(var_grads[-1]))
1233      self.assertEqual(9., dw)
1234
1235  def testCustomGradientWithVariablesEager(self):
1236    with context.eager_mode():
1237      layer = core_layers.Dense(4, use_bias=False)
1238
1239      @custom_gradient.custom_gradient
1240      def F(x):
1241        out = layer(x)
1242
1243        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1244          del out_grad
1245          self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1246          return (array_ops.ones((3, 2)),
1247                  [array_ops.ones((2, 4))])
1248
1249        return out, Grad
1250
1251      x = array_ops.ones((3, 2)) + 2.
1252      with backprop.GradientTape() as tape:
1253        tape.watch(x)
1254        y = F(x)
1255      w, = layer.variables
1256      dx, dw = tape.gradient(y, [x, w])
1257      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
1258      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())
1259
1260  @test_util.run_v1_only("b/120545219")
1261  def testCustomGradientErrorsWithNonResourceVariables(self):
1262
1263    def F(x, use_resource=False):
1264      with variable_scope.variable_scope("f", use_resource=use_resource):
1265        out = core_layers.dense(x, 4, use_bias=False)
1266
1267      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1268        del out_grad
1269        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1270        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])
1271
1272      return out, Grad
1273
1274    @custom_gradient.custom_gradient
1275    def FResource(x):
1276      return F(x, use_resource=True)
1277
1278    @custom_gradient.custom_gradient
1279    def FNonResource(x):
1280      return F(x, use_resource=False)
1281
1282    x = array_ops.ones((3, 2)) + 2.
1283
1284    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
1285    with variable_scope.variable_scope("vs1", use_resource=True):
1286      with self.assertRaisesWithPredicateMatch(TypeError,
1287                                               "must be `ResourceVariable`s"):
1288        FNonResource(x)
1289
1290    # Wrapping scope has use_resource=False but inner scope sets to True.
1291    # Passes.
1292    with variable_scope.variable_scope("vs2", use_resource=False):
1293      FResource(x)
1294
1295  @parameterized.parameters(True, False)
1296  def testCustomGradientVariablesKwonlyArgs(self, anonymous_varargs):
1297    with context.eager_mode():
1298      x_captured = variables.Variable(3.)  # Used by FuncMult
1299      @custom_gradient.custom_gradient
1300      def FuncMult(x):
1301        def ActualGrad(dy, variables):  # pylint: disable=redefined-outer-name
1302          self.assertLen(variables, 1)
1303          self.assertIs(variables[0], x_captured)
1304          x_captured_grad = 5. * x * dy
1305          return (4. * x_captured * dy, [x_captured_grad])
1306        # Define the returned GradMult, using varargs; "variables" is kwonlyarg
1307        if anonymous_varargs:
1308          def GradMult(dy, *, variables=None):  # pylint: disable=redefined-outer-name
1309            return ActualGrad(dy, variables)
1310        else:
1311          def GradMult(*dys, variables=None):  # pylint: disable=redefined-outer-name
1312            return ActualGrad(dys[0], variables)
1313
1314        return x * x_captured, GradMult
1315
1316      x = variables.Variable(6.)
1317      with backprop.GradientTape(persistent=True) as g:
1318        y = FuncMult(x)
1319      self.assertAllEqual(g.gradient(y, x), 4. * 3.)
1320
1321  def testWithNumpyInputs(self):
1322    with context.eager_mode():
1323
1324      @custom_gradient.custom_gradient
1325      def F(x):
1326        out = x
1327
1328        def Grad(_):
1329          return (None, None)
1330
1331        return out, Grad
1332
1333      x = np.ones((3, 2), dtype=np.float32)
1334      # Smoke test to ensure numpy inputs are accepted
1335      F(x)
1336
1337  @test_util.run_v1_only("b/120545219")
1338  def testRVGradientsDynamicCond(self):
1339    with self.cached_session():
1340      alpha = resource_variable_ops.ResourceVariable(
1341          np.random.random((1,)),
1342          dtype="float32")
1343
1344      conditional = array_ops.placeholder_with_default(True, shape=())
1345      output = control_flow_ops.cond(
1346          conditional, lambda: alpha * 2, lambda: alpha * 3)
1347
1348      g, = gradients_impl.gradients(output, alpha)
1349      self.evaluate(variables.global_variables_initializer())
1350      self.assertAllEqual(g, [2.0])
1351      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])
1352
1353  def testRecursiveCustomGradient(self):
1354    @custom_gradient.custom_gradient
1355    def F(x):
1356      out = core_layers.dense(x, 3, use_bias=False)
1357
1358      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1359        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1360        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1361        return grads[0], [array_ops.ones((4, 3))]
1362
1363      return out, Grad
1364
1365    @custom_gradient.custom_gradient
1366    def DoubleF(x):
1367      out = F(x)
1368
1369      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
1370        self.assertEqual(1, len(variables))  # pylint: disable=g-generic-assert
1371        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
1372        return grads[0], [array_ops.ones((4, 3))]
1373
1374      return out, Grad
1375    with ops.Graph().as_default():
1376      x = array_ops.ones((2, 4))
1377      with variable_scope.variable_scope("f", use_resource=True) as vs:
1378        y = DoubleF(x)
1379        all_vars = vs.global_variables()
1380        assert len(all_vars) == 1
1381      grads = gradients.gradients(y, [x, all_vars[0]])
1382      for g in grads:
1383        self.assertIsNotNone(g)
1384
1385      self.evaluate(variables.global_variables_initializer())
1386      dw = self.evaluate(math_ops.reduce_sum(grads[1]))
1387      self.assertEqual(12., dw)
1388
1389  @parameterized.named_parameters(
1390      [(("_%s_%s" % (x_struct, y_struct)).replace(" ", "").replace("None", ""),  # pylint: disable=g-complex-comprehension
1391        x_struct, y_struct)
1392       for y_struct in [[None, ()], (None, (), [], (None, ((), None)))]
1393       for x_struct in [(None, ()), (((), ()), [None, None], [], (None, ()))]
1394      ])
1395  @test_util.run_in_graph_and_eager_modes
1396  def testCustomGradientStructuralInputOutput(self, x_struct, y_struct):
1397    """Tests that custom_gradient can handle structured inputs/outputs."""
1398    def Zeros(x):
1399      return nest.map_structure(lambda _: array_ops.zeros([], "float32"), x)
1400    def GetStruct(x):
1401      return nest.map_structure(lambda _: None, x)
1402
1403    def MakeVjp(f, *x):
1404      with backprop.GradientTape(persistent=True) as tape:
1405        tape.watch(nest.flatten(x))
1406        y = f(*x)
1407      def Vjp(dy):
1408        return tape.gradient(y, x, output_gradients=dy)
1409      return y, Vjp
1410
1411    @custom_gradient.custom_gradient
1412    def F(*x):
1413      self.assertEqual(x_struct, GetStruct(x))
1414      def Vjp(*dy):
1415        self.assertEqual(len(nest.flatten(y_struct)),
1416                         len(nest.flatten(dy)))
1417        return nest.flatten(Zeros(x_struct))
1418      return Zeros(y_struct), Vjp
1419
1420    x, dy = Zeros([x_struct, y_struct])
1421    y, vjp = MakeVjp(F, *x)
1422    dx = vjp(dy)
1423    self.assertEqual(x_struct, GetStruct(dx))
1424    self.assertEqual(y_struct, GetStruct(y))
1425
1426
1427class TensorListGradientsTest(test_util.TensorFlowTestCase):
1428
1429  def testDefaultGradYs(self):
1430    with ops.Graph().as_default():
1431      tl = list_ops.empty_tensor_list(
1432          element_dtype=dtypes.float32,
1433          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1434      a = constant(1.0)
1435      tl = list_ops.tensor_list_push_back(tl, a)
1436
1437      grad_tl = list_ops.empty_tensor_list(
1438          element_dtype=dtypes.float32,
1439          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
1440      grad_tl = list_ops.tensor_list_push_back(tl, constant(5.0))
1441
1442      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
1443
1444      self.assertEqual(self.evaluate(grad), 5.)
1445
1446
1447class VariablesGradientTest(test_util.TensorFlowTestCase,
1448                            parameterized.TestCase):
1449
1450  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
1451    """Returns gradients of `test_model` with respect to `vars_to_grad`."""

    test_fn_re = custom_gradient.recompute_grad(test_fn)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_fn_re(inputs, vars_to_grad)
      out = test_fn(inputs, vars_to_grad)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  def _grad(self, f, argnums=0):
    """Return a function which computes the gradient of `f`."""

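    # F returns the gradient of f's output with respect to params[argnums];
    # unconnected gradients are returned as zeros rather than None.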
    def F(*params):
      with backprop.GradientTape() as tape:
        tape.watch(params)
        outputs = f(*params)
      return tape.gradient(
          outputs,
          params[argnums],
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO)

    return F

  def _test_gradients(self, f, inputs, order, delta=1e-3, rtol=1e-2, atol=1e-6):
    """Tests backward Jacobians of `f`'s gradients up to order `order`."""
    if order < 1:
      raise ValueError(
          "`order` should be a positive integer, got '{}'.".format(order))
    if order > 1:
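      # Recurse on the gradient function first: this checks f's derivatives of
      # orders 2..`order`, while the first-order check happens below.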
      self._test_gradients(
          f=self._grad(f),
          inputs=inputs,
          order=order - 1,
          delta=delta,
          rtol=rtol,
          atol=atol)
    sym_jac_back, num_jac = gradient_checker_v2.compute_gradient(
        f, inputs, delta=delta)
    self.assertAllClose(num_jac, sym_jac_back, rtol=rtol, atol=atol)

  @test_util.run_v2_only
  def testCustomGradientRecomputeGradHigherOrder(self):

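    # recompute_grad recomputes F's forward pass during backprop instead of
    # keeping activations; its gradients up to third order should still match
    # numeric estimates.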
    @custom_gradient.recompute_grad
    def F(x):
      return math_ops.reduce_prod(math_ops.tanh(x)**2)

    self._test_gradients(F, [constant_op.constant([1.])], order=3)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works for gradients of function args."""

    def TestFn(inputs, input_vars):
      return inputs * input_vars

    def TestFnSeq(inputs, input_vars):
      return (inputs * input_vars, inputs * input_vars * 2.0)

    with variable_scope.variable_scope("test", use_resource=True):
      test_var = variable_scope.get_variable(
          name="test_var",
          shape=10,
          trainable=True,
      )
      self.evaluate(test_var.assign(np.ones([10])))
      test_input = constant(np.ones((10, 10), dtype=np.float32))

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_input)

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      # Regression test for wrapping sequence-outputting functions.
      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_input)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

  @parameterized.parameters(set((True, context.executing_eagerly())))
  def testFnRecomputeWithScopeGradient(self, use_tape):
    """Checks that recompute_grad works with var scope and GradientTape."""

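    # TestFn creates its variable in an inner scope; AUTO_REUSE on the outer
    # scope below lets a second execution of TestFn (such as recompute_grad's
    # backward-pass rerun) resolve the same test_var rather than raising.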
    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)

      with test_util.AbstractGradientTape(
          use_tape=use_tape, persistent=True) as tape:
        out_re = test_fn_re(test_input_t)
        out = TestFn(test_input_t)

    self.evaluate(variables.global_variables_initializer())
    grads_re = tape.gradient(out_re, variables.trainable_variables())
    grads = tape.gradient(out, variables.trainable_variables())

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecomputeSameTensor(self):
    """Checks recompute_grad when the wrapped f is called as f(x, x) - b/147369366."""

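    # Both functions below are called as f(x, x), i.e. the same tensor serves
    # as the input and as the "variable"; this is the case covered by
    # b/147369366.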
    def TestFnMul(x, y):
      return x * y

    def TestFnSingleVar(x, y):
      # pylint: disable=unused-argument
      return x

    with variable_scope.variable_scope("test", use_resource=True):
      x = array_ops.ones((10,))

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar,
                                                      x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)


class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that the assign op is not differentiable.
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non-differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form the input as a quadratic function of variable z and check that the
    # gradient of the output with respect to z is correct.
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)

    with self.cached_session():
      self.evaluate(variables.global_variables_initializer())
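      # grad_pass_through forwards the gradient of z**2, so dy/dz = 2 * z = 6.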
      self.assertAllClose(grads[0], 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that the assign op is not differentiable.
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non-differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form the input as a quadratic function of variable z and check that the
    # gradient of the output with respect to z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()