# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import warnings

import numpy as np

from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import gradients_util
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest


class GradientsTest(test_util.TensorFlowTestCase):

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
      h = bias_add(xw, b, name="h")
      w_grad = gradients.gradients(h, w)[0]
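    # For xw = inp @ w, the gradient w.r.t. w is inp^T @ grad(h), which the
    # MatMul gradient function emits as a MatMul with transpose_a=True.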
    self.assertEqual("MatMul", w_grad.op.type)
    self.assertEqual(w_grad.op._original_op, xw.op)
    self.assertTrue(w_grad.op.get_attr("transpose_a"))
    self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
      gw = gradients.gradients(c, [w])[0]
    self.assertEqual("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      y = constant(1.0, shape=[1, 2])
      wx = math_ops.matmul(w, x)
      wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(wx.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertTrue(w.op.colocation_groups() != gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]
      with session.Session():
        # Make sure the placer doesn't complain.
        self.evaluate(gz_x)

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  @test_util.run_v1_only("b/120545219")
  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAccumulateN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
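      # z = 10 * y and y = 2 * x, so dz/dy = 10 and dz/dx = 20.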
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAddN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodTree(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertTrue(isinstance(grads[0], ops.Tensor))
      grads = gradients.gradients(w, [c])
      self.assertTrue(isinstance(grads[0], ops.Tensor))

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = ops.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.cached_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
      self.assertEqual(vdx, vdy)

  @test_util.run_v1_only("b/120545219")
  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.VariableV1(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testDependentYs(self):
    with self.cached_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
      g = gradients.gradients([y, y2], x)
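      # With x = 3: y = x**2 and y2 = x**8, so the aggregate gradient is
      # 2*x + 8*x**7 = 6 + 17496 = 17502, however the ys are combined below.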
      self.assertAllClose(17502.0, g[0].eval())
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0].eval())
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testPartialDerivatives(self):
    with self.cached_session():
      x = constant_op.constant(1.)
      y = 2 * x
      z = x + y
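      # z depends on x both directly and through y = 2x: the total derivative
      # dz/dx is 1 + 2 = 3, while stop_gradients=[x, y] keeps only the direct
      # paths, giving 1 for each input.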
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  @test_util.run_v1_only("b/120545219")
  def testStopGradients(self):
    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))

      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
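      # gradients() returns None for disconnected inputs; substitute zeros of
      # the matching shape so grad1 and grad2 can be compared elementwise.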
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradients
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.cached_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
    self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertEqual(3.0, self.evaluate(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegexp(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")


class FunctionGradientsTest(test_util.TensorFlowTestCase):

  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return framework_function.Defun(dtypes.float32, dtypes.float32, **
                                    kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")

    y = f(x, b)
    grads = gradients.gradients(y, [x, b])
    with self.cached_session() as sess:
      return sess.run(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
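      # y = x**2 + b, so dy/dx = 2*x = 4 and dy/db = 1 at x = 2.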
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b1])
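      # y = (x**2 + b1)**2 + b2, so dy/dx = 4*x*(x**2 + b1) = 40 and
      # dy/db1 = 2*(x**2 + b1) = 10 at x = 2, b1 = 1.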

      with self.cached_session() as sess:
        self.assertAllEqual([40.0], self.evaluate(grads)[0])
        self.assertAllEqual([10.0], self.evaluate(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())

  def testGradientWrtCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Foo():
        y = math_ops.multiply(x, 2.0, name="y")
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testGradientOfCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
      y = math_ops.multiply(x, 2.0, name="y")

      @framework_function.Defun()
      def Foo():
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedResourceVariable(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0, name="var")

      @function.defun()
      def Foo():
        y = math_ops.multiply(var, 2.0, name="y")
        g = gradients_impl.gradients(y, var)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedNested(self):
    with ops.Graph().as_default():
      x1 = constant_op.constant(1.0, name="x1")
      x2 = constant_op.constant(2.0, name="x2")
      x3 = math_ops.multiply(x1, x2, name="x3")

      @function.defun()
      def Outer():
        outer1 = array_ops.identity(x1, name="outer1")

        @function.defun()
        def Inner():
          inner1 = array_ops.identity(outer1, name="inner1")
          inner2 = array_ops.identity(x2, name="inner2")
          inner3 = array_ops.identity(x3, name="inner3")
          return gradients_impl.gradients([inner1, inner2, inner3, x1],
                                          [x1, x2])

        return Inner()

      x1_grad, x2_grad = Outer()
      with self.cached_session() as sess:
        # 1.0 + None + 2.0 + 1.0 = 4.0
        self.assertEqual(self.evaluate(x1_grad), 4.0)
        # None + 1.0 + 1.0 + None = 2.0
        self.assertEqual(self.evaluate(x2_grad), 2.0)

  def testCapturedFromFunction(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Outer():
        y = math_ops.multiply(x, 2.0, name="y")

        @function.defun()
        def Inner():
          z = math_ops.multiply(y, 3.0, name="z")
          g = gradients_impl.gradients(z, y)
          return g[0]

        return Inner()

      z_grad = Outer()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(z_grad), 3.0)

  def testCapturedEagerTensors(self):
    # Test that we can handle captured eager tensors unrelated to the gradient
    # computation (i.e. we need to ignore them).
    # TODO(skyewm): make it an error if you try to take the gradient wrt a
    # captured EagerTensor
    with context.eager_mode():
      c = constant_op.constant(2.0, name="c")

      @function.defun
      def Foo():
        x = constant_op.constant(10.0, name="x")
        y = math_ops.multiply(x, c, name="y")
        # Regression test for b/122564611.
        z = math_ops.multiply(c, y, name="z")
        g = gradients_impl.gradients(z, x)
        return g[0]

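      # z = c * (x * c) = c**2 * x, so dz/dx = c**2 = 4.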
      self.assertEqual(Foo().numpy(), 4.0)


class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
    assert igrad is None


class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


class HessianVectorProductTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
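    # (The gradient of f is (A + A^T) x, so differentiating once more
    # gives H.)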
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = self.evaluate(hess_v)
      self.assertAllClose(hess_v_value, hess_v_actual)


class HessianTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.session(use_gpu=True):
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = self.evaluate(hess)
    self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.session(use_gpu=True):
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
    for hess_value, hess_actual in zip(hess_values, hessians_actual):
      self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.cached_session(use_gpu=True):
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((m, m)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, m, m, m), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((n, n)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, n, m, n), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensor(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensorList(self):
    with self.cached_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse))

  @test_util.run_v1_only("b/120545219")
  def testInt64Indices(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = ops.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")

    # Smaller than the threshold: no warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
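    # 100**4 = 100,000,000 elements, at or above the warning threshold.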
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "with 100000000 elements. This may consume a large amount of memory." in
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "of unknown shape. This may consume a large amount of memory." in
        str(w[0].message))


class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testRealOnly(self):
    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegexp(
        TypeError,
        r"Gradients of complex tensors must set grad_ys "
        r"\(y\.dtype = tf\.complex64\)"):
      gradients.gradients(y, x)


class ResourceCondTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    gamma = resource_variable_ops.ResourceVariable(
        np.random.random((3,)),
        dtype="float32", name="gamma")

    inputs = array_ops.ones(shape=(3,), dtype="float32")

    def TestFn():
      output = inputs + gamma
      return output

    training = array_ops.placeholder_with_default(True, shape=())
    output = control_flow_ops.cond(
        training, TestFn, lambda: inputs)

    loss = output

    grads = gradients.gradients(
        loss, [gamma])
    self.assertTrue(None not in grads)


class CustomGradientTest(test_util.TensorFlowTestCase):

  def testCustomGradientTrivial(self):

    @custom_gradient.custom_gradient
    def MyIdentity(x):

      def Grad(dy):
        return [3 * dy]

      return x, Grad

    with ops.Graph().as_default():
      x = constant(3.)
      y = MyIdentity(MyIdentity(x))
      dy = gradients.gradients(y, x)[0]
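      # Each MyIdentity multiplies the incoming gradient by 3, so the
      # composed gradient is 3 * 3 = 9.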
      with session.Session():
        self.assertEqual(9., self.evaluate(dy))

  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def MyMultiply(x1, x2):
      result = x1 * x2

      def Grad(dy):
        # Intentionally switched from the correct ordering [dy * x2, dy * x1]
        # so the assertion below can tell that this custom gradient was used.
        return [dy * x1, dy * x2]

      return result, Grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      y = MyMultiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      with session.Session() as sess:
        self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientErrors(self):

    @custom_gradient.custom_gradient
    def F(x):

      def Grad(_):
        raise RuntimeError("x")

      return x, Grad

    with ops.Graph().as_default():
      x = constant(1.0)
      y = F(x)
      with self.assertRaises(RuntimeError):
        gradients.gradients(y, x)

  def testCustomGradientWithVariables(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = F(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertTrue(g is not None)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)

  def testCustomGradientWithVariablesEager(self):
    with context.eager_mode():
      layer = core_layers.Dense(4, use_bias=False)

      @custom_gradient.custom_gradient
      def F(x):
        out = layer(x)

        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
          del out_grad
          self.assertEqual(1, len(variables))
          return (array_ops.ones((3, 2)),
                  [array_ops.ones((2, 4))])

        return out, Grad

      x = array_ops.ones((3, 2)) + 2.
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = F(x)
      w, = layer.variables
      dx, dw = tape.gradient(y, [x, w])
      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())

  @test_util.run_v1_only("b/120545219")
  def testCustomGradientErrorsWithNonResourceVariables(self):

    def F(x, use_resource=False):
      with variable_scope.variable_scope("f", use_resource=use_resource):
        out = core_layers.dense(x, 4, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        del out_grad
        self.assertEqual(1, len(variables))
        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])

      return out, Grad

    @custom_gradient.custom_gradient
    def FResource(x):
      return F(x, use_resource=True)

    @custom_gradient.custom_gradient
    def FNonResource(x):
      return F(x, use_resource=False)

    x = array_ops.ones((3, 2)) + 2.

    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
    with variable_scope.variable_scope("vs1", use_resource=True):
      with self.assertRaisesWithPredicateMatch(TypeError,
                                               "must be `ResourceVariable`s"):
        FNonResource(x)

    # Wrapping scope has use_resource=False but inner scope sets to True.
    # Passes.
    with variable_scope.variable_scope("vs2", use_resource=False):
      FResource(x)

  def testWithNumpyInputs(self):
    with context.eager_mode():

      @custom_gradient.custom_gradient
      def F(x):
        out = x

        def Grad(_):
          return (None, None)

        return out, Grad

      x = np.ones((3, 2), dtype=np.float32)
      # Smoke test to ensure numpy inputs are accepted
      F(x)

  @test_util.run_v1_only("b/120545219")
  def testRVGradientsDynamicCond(self):
    with self.cached_session():
      alpha = resource_variable_ops.ResourceVariable(
          np.random.random((1,)),
          dtype="float32")

      conditional = array_ops.placeholder_with_default(True, shape=())
      output = control_flow_ops.cond(
          conditional, lambda: alpha * 2, lambda: alpha * 3)

      g, = gradients_impl.gradients(output, alpha)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(g.eval(), [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])

  def testRecursiveCustomGradient(self):
    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    @custom_gradient.custom_gradient
    def DoubleF(x):
      out = F(x)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad
    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = DoubleF(x)
        all_vars = vs.global_variables()
        assert len(all_vars) == 1
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)


class AggregateIndexedSlicesGradientsTest(test_util.TensorFlowTestCase):

  def _assert_indexed_slices_equal(self, left, right):
    self.assertAllEqual(
        self.evaluate(ops.convert_to_tensor(left)),
        self.evaluate(ops.convert_to_tensor(right)))

  def testNoGradients(self):
    self.assertIsNone(gradients_util._AggregateIndexedSlicesGradients([]))

  def testOneGradient(self):
    t = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    result = gradients_util._AggregateIndexedSlicesGradients([t])
    self._assert_indexed_slices_equal(t, result)

  def testMultipleGradients(self):
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]]))
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_util._AggregateIndexedSlicesGradients([t0, t1])
    self._assert_indexed_slices_equal(total, result)

  def testMultipleGradientsWithNones(self):
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = math_ops._as_indexed_slices(constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]]))
    t3 = None
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_util._AggregateIndexedSlicesGradients([t0, t1, t3])
    self._assert_indexed_slices_equal(total, result)

  def testMixedTensorAndIndexedSlices(self):
    t0 = math_ops._as_indexed_slices(constant_op.constant(
        [[1., 2.], [0, 0], [3., 4.]]))
    t1 = constant_op.constant(
        [[0., 0.], [5, 6], [7., 8.]])
    total = constant_op.constant(
        [[1., 2.], [5, 6], [10., 12.]])
    result = gradients_util._AggregateIndexedSlicesGradients([t0, t1])
    self._assert_indexed_slices_equal(total, result)


class TensorListGradientsTest(test_util.TensorFlowTestCase):

  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      # Push onto the freshly created grad list (the original pushed onto `tl`,
      # leaving the empty grad_tl above unused).
      grad_tl = list_ops.tensor_list_push_back(grad_tl, constant(5.0))

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(grad), 5.)


if __name__ == "__main__":
  googletest.main()