# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tensorflow.ops.gradients."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import warnings

from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function as framework_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops
from tensorflow.python.framework import test_util
from tensorflow.python.framework.constant_op import constant
from tensorflow.python.keras.engine import training
from tensorflow.python.layers import core as core_layers
from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import data_flow_grad  # pylint: disable=unused-import
from tensorflow.python.ops import data_flow_ops  # pylint: disable=unused-import
from tensorflow.python.ops import functional_ops  # pylint: disable=unused-import
from tensorflow.python.ops import gradients
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import list_ops
from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_grad  # pylint: disable=unused-import
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_grad  # pylint: disable=unused-import
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import unconnected_gradients
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.nn_ops import bias_add
from tensorflow.python.platform import googletest


class GradientsTest(test_util.TensorFlowTestCase, parameterized.TestCase):

  def testGradients(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[32, 100], name="in")
      w = constant(1.0, shape=[100, 10], name="w")
      b = constant(1.0, shape=[10], name="b")
      xw = math_ops.matmul(inp, w, name="xw")
      h = bias_add(xw, b, name="h")
      w_grad = gradients.gradients(h, w)[0]
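      # For h = matmul(inp, w) + b, dh/dw is matmul(inp, dh, transpose_a=True),
      # so the gradient op should be a MatMul with transpose_a set.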
    self.assertEquals("MatMul", w_grad.op.type)
    self.assertEquals(w_grad.op._original_op, xw.op)
    self.assertTrue(w_grad.op.get_attr("transpose_a"))
    self.assertFalse(w_grad.op.get_attr("transpose_b"))

  def testUnusedOutput(self):
    with ops.Graph().as_default():
      w = constant(1.0, shape=[2, 2])
      x = constant(1.0, shape=[2, 2])
      wx = math_ops.matmul(w, x)
      split_wx = array_ops.split(value=wx, num_or_size_splits=2, axis=0)
      c = math_ops.reduce_sum(split_wx[1])
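      # Only the second half of the split contributes to c, but dc/dw should
      # still be produced by a single MatMul op.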
      gw = gradients.gradients(c, [w])[0]
    self.assertEquals("MatMul", gw.op.type)

  def testColocateGradients(self):
    with ops.Graph().as_default() as g:
      w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      with g.device("/device:GPU:0"):
        wx = math_ops.matmul(w, x)
      gw = gradients.gradients(wx, [w], colocate_gradients_with_ops=True)[0]
    self.assertEqual(gw.op.colocation_groups(), wx.op.colocation_groups())

  def testColocateGradientsWithAggregation(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      y = constant(1.0, shape=[1, 2])
      wx = math_ops.matmul(w, x)
      wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), wx.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertNotEqual(wx.op.colocation_groups(),
                          gw2.op.colocation_groups())

  def testColocateGradientsWithAggregationInMultipleDevices(self):
    with ops.Graph().as_default() as g:
      with g.device("/device:GPU:1"):
        w = constant(1.0, shape=[1, 1])
      x = constant(1.0, shape=[1, 2])
      y = constant(1.0, shape=[1, 2])
      with g.device("/task:1"):
        wx = math_ops.matmul(w, x)
      with g.device("/task:2"):
        wy = math_ops.matmul(w, y)
      with g.device("/device:GPU:0"):
        z = wx + wy

      gw1 = gradients.gradients(z, [w], colocate_gradients_with_ops=True)[0]
      self.assertEqual(gw1.op.colocation_groups(), w.op.colocation_groups())

      gw2 = gradients.gradients(z, [w], colocate_gradients_with_ops=False)[0]
      self.assertNotEqual(w.op.colocation_groups(),
                          gw2.op.colocation_groups())

  def testColocateGradientsWithGateGradients(self):
    if not test_util.is_gpu_available():
      self.skipTest("No GPU available")
    with ops.Graph().as_default() as g:
      with g.device("/device:CPU:0"):
        x = constant(1.0, shape=[1, 1])
        y = constant(1.0, shape=[1, 1])
        s = x + y
      with g.device("/device:GPU:0"):
        z = math_ops.reduce_sum(s)

      gz_x = gradients.gradients(z, [x], colocate_gradients_with_ops=True,
                                 gate_gradients=True)[0]
      with session.Session():
        # Make sure the placer doesn't complain.
        self.evaluate(gz_x)

  def testBoundaryStop(self):
    # Test that we don't differentiate 'x'. The gradient function for 'x' is
    # set explicitly to None so we will get an exception if the gradient code
    # tries to differentiate 'x'.
    with ops.Graph().as_default():
      c = constant(1.0)
      x = array_ops.identity(c)
      y = x + 1.0
      z = y + 1
      grads = gradients.gradients(z, [x])
      self.assertTrue(all(x is not None for x in grads))

  @test_util.run_v1_only("b/120545219")
  def testBoundaryContinue(self):
    # Test that we differentiate both 'x' and 'y' correctly when x is a
    # predecessor of y.
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y * 3.0
      grads = gradients.gradients(z, [x, y])
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(6.0, grads[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAccumulateN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
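      # z = 10 * y and y = 2 * x, so dz/dy == 10 and dz/dx == 20.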
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.
          EXPERIMENTAL_ACCUMULATE_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodAddN(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y], aggregation_method=gradients.AggregationMethod.ADD_N)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  @test_util.run_v1_only("b/120545219")
  def testAggregationMethodTree(self):
    with self.cached_session():
      x = constant(1.0)
      y = x * 2.0
      z = y + y + y + y + y + y + y + y + y + y
      grads = gradients.gradients(
          z, [x, y],
          aggregation_method=gradients.AggregationMethod.EXPERIMENTAL_TREE)
      self.assertTrue(all(x is not None for x in grads))
      self.assertEqual(20.0, grads[0].eval())
      self.assertEqual(10.0, grads[1].eval())

  def testNoGradientForStringOutputs(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testNoGradientForStringOutputsWithOpNamespace(self):
    with ops.Graph().as_default():

      def _TestOpGrad(_, float_grad, string_grad):
        """Gradient function for TestStringOutput."""
        self.assertEqual(float_grad.dtype, dtypes.float32)
        self.assertFalse(string_grad)
        return float_grad

      ops.RegisterGradient("Namespace>TestStringOutput")(_TestOpGrad)

      c = constant(1.0)
      x, _ = test_ops.namespace_test_string_output(c)
      z = x * 2.0
      w = z * 3.0
      grads = gradients.gradients(z, [c])
      self.assertIsInstance(grads[0], ops.Tensor)
      grads = gradients.gradients(w, [c])
      self.assertIsInstance(grads[0], ops.Tensor)

  def testSingletonIndexedSlices(self):
    with ops.Graph().as_default():
      x = array_ops.placeholder(dtypes.float32)
      y = array_ops.identity(x)
      dy = ops.IndexedSlices(
          array_ops.placeholder(dtypes.float32),
          array_ops.placeholder(dtypes.int32))
      dx, = gradients.gradients(y, x, grad_ys=dy)
      # The IndexedSlices gradient of tf.identity is the identity map.
      with self.cached_session() as sess:
        vdx, vdy = sess.run(
            [dx, dy], feed_dict={x: [1.0], dy.indices: [0], dy.values: [2.0]})
      self.assertEqual(vdx, vdy)

  @test_util.run_v1_only("b/120545219")
  def testNonDifferentiableSwitchInWhileLoop(self):
    with ops.Graph().as_default():
      v = array_ops.placeholder(dtypes.float32, [])

      def _Step(i, a, ta):
        a += math_ops.cast(v, dtypes.int32)
        return (i + 1, a, ta.write(i, a))

      n = 4
      i, _, ta = control_flow_ops.while_loop(
          lambda i, *_: i < n,
          _Step, [0, 0, tensor_array_ops.TensorArray(
              dtypes.int32, size=n)])
      target = ta.read(i - 1)
      grad, = gradients.gradients(target, v)
      self.assertIsNone(grad)

  def testVariableReadValueGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(var.read_value(), var)
      self.assertIsNotNone(gradient)

  @parameterized.parameters(dtypes.float32, dtypes.float64)
  def testVariableDefaultGrad(self, dtype):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, dtype=dtype)
      var = variables.Variable(init)
      dummy_const = constant_op.constant(0.0)
      gradient = gradients.gradients(
          dummy_const,
          var,
          unconnected_gradients=unconnected_gradients.UnconnectedGradients.ZERO
      )[0]
      self.assertEqual(gradient.dtype, dtype)
      self.assertIsNotNone(gradient)

  def testVariableAsGraphElementGradient(self):
    with ops.Graph().as_default() as graph:
      init = constant_op.constant(100.0)
      var = variables.Variable(init)
      gradient = gradients.gradients(graph.as_graph_element(var), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testVariableRefGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.VariableV1(init)
      gradient = gradients.gradients(var._ref(), var)
      self.assertIsNotNone(gradient)

  @test_util.run_v1_only("b/120545219")
  def testDependentYs(self):
    with self.cached_session():
      x = constant_op.constant(3.0)
      y = math_ops.square(x)
      y1 = math_ops.square(y)
      y2 = math_ops.square(y1)
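      # With x = 3: y = x**2 and y2 = x**8, so
      # d(y + y2)/dx = 2*x + 8*x**7 = 6 + 17496 = 17502.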
      g = gradients.gradients([y, y2], x)
      self.assertAllClose(17502.0, g[0].eval())
      g = gradients.gradients(y + y2, x)
      self.assertAllClose(17502.0, g[0].eval())
      z = array_ops.identity(y)
      z2 = array_ops.identity(y2)
      g = gradients.gradients([z, z2], x)
      self.assertAllClose(17502.0, g[0].eval())

  @test_util.run_v1_only("b/120545219")
  def testPartialDerivatives(self):
    with self.cached_session():
      x = constant_op.constant(1.)
      y = 2 * x
      z = x + y
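      # The total derivative dz/dx is 3 (since y = 2 * x), while the partial
      # derivative with y held constant via stop_gradients is 1.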
      totalg = gradients.gradients(z, [x, y])
      self.assertEqual([3.0, 1.0], [g.eval() for g in totalg])
      partialg = gradients.gradients(z, [x, y], stop_gradients=[x, y])
      self.assertEqual([1.0, 1.0], [g.eval() for g in partialg])

  @test_util.run_v1_only("b/120545219")
  def testStopGradients(self):
    def _MakeGraph(rng, stop_gradients=()):
      def _FunctionOf(xs, k=3):
        return ops.convert_to_tensor(
            sum(math_ops.matmul(rng.rand(k, k), x) for x in xs)
            + rng.rand(k, k))

      a = _FunctionOf([])
      if "a" in stop_gradients: a = array_ops.stop_gradient(a)
      b = _FunctionOf([a])
      if "b" in stop_gradients: b = array_ops.stop_gradient(b)
      c = _FunctionOf([a, b])
      if "c" in stop_gradients: c = array_ops.stop_gradient(c)
      d = _FunctionOf([b, c])
      if "d" in stop_gradients: d = array_ops.stop_gradient(d)
      return dict(a=a, b=b, c=c, d=d)

    def _Gradients(ys, xs, **kwargs):
      dydxs = gradients.gradients(ys, xs, **kwargs)
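      # Treat disconnected (None) gradients as zeros so the two variants can
      # be compared numerically below.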
      dydxs = [0. * x if dydx is None else dydx
               for x, dydx in zip(xs, dydxs)]
      return dydxs

    seed = np.random.randint(1000)
    cases = []
    subsets = [""] + "a b c d ab ac ad bc bd cd abc abd acd bcd abcd".split()
    graph = _MakeGraph(np.random.RandomState(seed))
    for constants in subsets:
      graph_with_stops = _MakeGraph(np.random.RandomState(seed), constants)
      for variables_ in subsets:
        # compute the gradient when stopped using tf.stop_gradients
        grad1 = _Gradients([graph_with_stops["d"]],
                           [graph_with_stops[v] for v in variables_])
        # compute the gradient when stopped using the stop_gradients kwarg
        grad2 = _Gradients([graph["d"]],
                           [graph[v] for v in variables_],
                           stop_gradients=[graph[v] for v in constants])
        cases.append(dict(grad1=grad1, grad2=grad2,
                          constants=constants, variables=variables_))

    # evaluate all tensors in one call to session.run for speed
    with self.cached_session() as sess:
      results = sess.run([(case["grad1"], case["grad2"]) for case in cases])

    for (npgrad1, npgrad2), case in zip(results, cases):
      for a, b in zip(npgrad1, npgrad2):
        np.testing.assert_allclose(a, b)

  def testUnconnectedGradientsNoneUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="none")
    self.assertIsNone(grad[0])

  def testUnconnectedGradientsZerosUnconnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0, shape=[2, 2])
      y = constant(3.0, shape=[3, 1])
      grads = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], self.evaluate(grads)[0])

  def testUnconnectedGradientsZeroConnectedGradients(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = x * 3.0
      grad = gradients.gradients(
          [y], [x], unconnected_gradients="zero")
      with self.cached_session() as sess:
        self.assertEqual(3.0, self.evaluate(grad)[0])

  def testUnknownUnconnectedGradientsValueGiven(self):
    with ops.Graph().as_default():
      x = constant(1.0)
      y = constant(1.0)
      with self.assertRaisesRegexp(
          ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
        gradients.gradients([y], [x], unconnected_gradients="nonsense")


class FunctionGradientsTest(test_util.TensorFlowTestCase):

  @classmethod
  def XSquarePlusB(cls, x, b):
    return x * x + b

  @classmethod
  def XSquarePlusBGradient(cls, x, b, g):
    # Perturb gradients (multiply by 2), so we can test that this was called.
    g *= 2.0
    return g * 2.0 * x, g

  @classmethod
  def _PythonGradient(cls, op, grad):
    # Perturb gradients (multiply by 3), so we can test that this was called.
    grad *= 3.0
    return grad * op.inputs[0] * 2.0, grad

  @classmethod
  def _GetFunc(cls, **kwargs):
    return framework_function.Defun(dtypes.float32, dtypes.float32,
                                    **kwargs)(cls.XSquarePlusB)

  def _GetFuncGradients(self, f, x_value, b_value):
    x = constant_op.constant(x_value, name="x")
    b = constant_op.constant(b_value, name="b")

    y = f(x, b)
    grads = gradients.gradients(y, [x, b])
    with self.cached_session() as sess:
      return sess.run(grads)

  def testFunctionGradientsBasic(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc()
      # Get gradients (should add SymbolicGradient node for function).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0], grads[0])
      self.assertAllEqual([1.0], grads[1])

  def testFunctionGradientsComposition(self):
    with ops.Graph().as_default():
      f = self._GetFunc()
      x = constant_op.constant([2.0], name="x")
      b1 = constant_op.constant([1.0], name="b1")
      b2 = constant_op.constant([1.0], name="b2")

      y = f(f(x, b1), b2)
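      # y = (x**2 + b1)**2 + b2, so dy/dx = 4 * x * (x**2 + b1) = 40 and
      # dy/db1 = 2 * (x**2 + b1) = 10.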
      # Build gradient graph (should add SymbolicGradient node for function).
      grads = gradients.gradients(y, [x, b1])

      with self.cached_session() as sess:
        self.assertAllEqual([40.0], self.evaluate(grads)[0])
        self.assertAllEqual([10.0], self.evaluate(grads)[1])

  def testFunctionGradientsWithGradFunc(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      f = self._GetFunc(grad_func=grad_func)
      # Get gradients (should add SymbolicGradient node for function, which
      # uses the grad_func above, which multiplies all gradients by 2).
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 2], grads[0])
      self.assertAllEqual([1.0 * 2], grads[1])

  def testFunctionGradientWithRegistration(self):
    g = ops.Graph()
    with g.as_default():
      f = self._GetFunc(python_grad_func=self._PythonGradient)
      # Get gradients, using the python gradient function. It multiplies the
      # gradients by 3.
      grads = self._GetFuncGradients(f, [2.0], [1.0])
      self.assertAllEqual([4.0 * 3], grads[0])
      self.assertAllEqual([1.0 * 3], grads[1])

  def testFunctionGradientWithGradFuncAndRegistration(self):
    g = ops.Graph()
    with g.as_default():
      grad_func = framework_function.Defun(dtypes.float32, dtypes.float32,
                                           dtypes.float32)(
                                               self.XSquarePlusBGradient)
      with self.assertRaisesRegexp(ValueError, "Gradient defined twice"):
        f = self._GetFunc(
            grad_func=grad_func, python_grad_func=self._PythonGradient)
        f.add_to_graph(ops.Graph())

  def testGradientWrtCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Foo():
        y = math_ops.multiply(x, 2.0, name="y")
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testGradientOfCaptured(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")
      y = math_ops.multiply(x, 2.0, name="y")

      @framework_function.Defun()
      def Foo():
        g = gradients_impl.gradients(y, x)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedResourceVariable(self):
    with ops.Graph().as_default():
      var = resource_variable_ops.ResourceVariable(1.0, name="var")

      @function.defun()
      def Foo():
        y = math_ops.multiply(var, 2.0, name="y")
        g = gradients_impl.gradients(y, var)
        return g[0]

      f = Foo()
      with self.cached_session() as sess:
        self.evaluate(variables.global_variables_initializer())
        self.assertEqual(self.evaluate(f), 2.0)

  def testCapturedNested(self):
    with ops.Graph().as_default():
      x1 = constant_op.constant(1.0, name="x1")
      x2 = constant_op.constant(2.0, name="x2")
      x3 = math_ops.multiply(x1, x2, name="x3")

      @function.defun()
      def Outer():
        outer1 = array_ops.identity(x1, name="outer1")

        @function.defun()
        def Inner():
          inner1 = array_ops.identity(outer1, name="inner1")
          inner2 = array_ops.identity(x2, name="inner2")
          inner3 = array_ops.identity(x3, name="inner3")
          return gradients_impl.gradients([inner1, inner2, inner3, x1],
                                          [x1, x2])

        return Inner()

      x1_grad, x2_grad = Outer()
      with self.cached_session() as sess:
        # 1.0 + None + 2.0 + 1.0 = 4.0
        self.assertEqual(self.evaluate(x1_grad), 4.0)
        # None + 1.0 + 1.0 + None = 2.0
        self.assertEqual(self.evaluate(x2_grad), 2.0)

  def testCapturedFromFunction(self):
    with ops.Graph().as_default():
      x = constant_op.constant(1.0, name="x")

      @function.defun()
      def Outer():
        y = math_ops.multiply(x, 2.0, name="y")

        @function.defun()
        def Inner():
          z = math_ops.multiply(y, 3.0, name="z")
          g = gradients_impl.gradients(z, y)
          return g[0]

        return Inner()

      z_grad = Outer()
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(z_grad), 3.0)

  def testCapturedEagerTensors(self):
    # Test that we can handle captured eager tensors unrelated to the gradient
    # computation (i.e. we need to ignore them).
    # TODO(skyewm): make it an error if you try to take the gradient wrt a
    # captured EagerTensor
    with context.eager_mode():
      c = constant_op.constant(2.0, name="c")

      @function.defun
      def Foo():
        x = constant_op.constant(10.0, name="x")
        y = math_ops.multiply(x, c, name="y")
        # Regression test for b/122564611.
        z = math_ops.multiply(c, y, name="z")
        g = gradients_impl.gradients(z, x)
        return g[0]

      self.assertEqual(Foo().numpy(), 4.0)


class StopGradientTest(test_util.TensorFlowTestCase):

  def testStopGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.stop_gradient(inp)
      igrad = gradients.gradients(out, inp)[0]
    self.assertIsNone(igrad)


class PreventGradientTest(test_util.TensorFlowTestCase):

  def testPreventGradient(self):
    with ops.Graph().as_default():
      inp = constant(1.0, shape=[100, 32], name="in")
      out = array_ops.prevent_gradient(inp)
      with self.assertRaisesRegexp(LookupError, "explicitly disabled"):
        _ = gradients.gradients(out, inp)


class HessianVectorProductTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessianVectorProduct(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that HessianVectorProduct matches multiplication by the
    # explicit Hessian.
    # Specifically, the Hessian of f(x) = x^T A x is
    # H = A + A^T.
    # We expect HessianVectorProduct(f(x), x, v) to be H v.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    v_value = rng.randn(m, 1).astype("float32")
    x_value = rng.randn(m, 1).astype("float32")
    hess_value = mat_value + mat_value.T
    hess_v_value = np.dot(hess_value, v_value)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu):
        mat = constant_op.constant(mat_value)
        v = constant_op.constant(v_value)
        x = constant_op.constant(x_value)
        mat_x = math_ops.matmul(mat, x, name="Ax")
        x_mat_x = math_ops.matmul(array_ops.transpose(x), mat_x, name="xAx")
        hess_v = gradients_impl._hessian_vector_product(x_mat_x, [x], [v])[0]
        hess_v_actual = self.evaluate(hess_v)
      self.assertAllClose(hess_v_value, hess_v_actual)


class HessianTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testHessian1D(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = x^T A x is H = A + A^T.
    m = 4
    rng = np.random.RandomState([1, 2, 3])
    mat_value = rng.randn(m, m).astype("float32")
    x_value = rng.randn(m).astype("float32")
    hess_value = mat_value + mat_value.T
    with self.session(use_gpu=True):
      mat = constant_op.constant(mat_value)
      x = constant_op.constant(x_value)
      x_mat_x = math_ops.reduce_sum(x[:, None] * mat * x[None, :])
      hess = gradients.hessians(x_mat_x, x)[0]
      hess_actual = self.evaluate(hess)
    self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessian1D_multi(self):
    # Test the computation of the hessian with respect to multiple tensors
    m = 4
    n = 3
    rng = np.random.RandomState([1, 2, 3])
    mat_values = [rng.randn(m, m).astype("float32") for _ in range(n)]
    x_values = [rng.randn(m).astype("float32") for _ in range(n)]
    hess_values = [mat_value + mat_value.T for mat_value in mat_values]
    with self.session(use_gpu=True):
      mats = [constant_op.constant(mat_value) for mat_value in mat_values]
      xs = [constant_op.constant(x_value) for x_value in x_values]
      xs_mats_xs = [
          math_ops.reduce_sum(x[:, None] * mat * x[None, :])
          for x, mat in zip(xs, mats)
      ]
      hessians = gradients.hessians(xs_mats_xs, xs)
      hessians_actual = [hess.eval() for hess in hessians]
    for hess_value, hess_actual in zip(hess_values, hessians_actual):
      self.assertAllClose(hess_value, hess_actual)

  @test_util.run_v1_only("b/120545219")
  def testHessianInvalidDimension(self):
    for shape in [(10, 10), None]:
      with self.cached_session(use_gpu=True):
        x = array_ops.placeholder(dtypes.float32, shape)
        # Expect a ValueError because the dimensions are wrong
        with self.assertRaises(ValueError):
          gradients.hessians(x, x)

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_square_matrix(self):
    # Manually compute the Hessian explicitly for a low-dimensional problem
    # and check that `hessian` matches. Specifically, the Hessian of
    # f(x) = 1/2 * x^T * x is H = constant (block identity matrix)
    m = 3
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, m).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
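    # Expected Hessian: all-ones m x m blocks on the block diagonal, i.e.
    # H[a, b, c, d] = 1 if a == c else 0.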
    hess_value = np.bmat([
        [elem*np.ones((m, m)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, m, m, m), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * m, m * m)))

  @test_util.run_v1_only("b/120545219")
  def testHessian2D_non_square_matrix(self):
    m = 3
    n = 4
    rng = np.random.RandomState([1, 2, 3])
    x_value = rng.randn(m, n).astype("float32")
    with self.session(use_gpu=True):
      x = constant_op.constant(x_value)
      x_square = math_ops.reduce_sum(
          math_ops.matmul(array_ops.transpose(x), x) * 0.5
      )
      hess = gradients.hessians(x_square, x)[0]
      hess_actual = self.evaluate(hess)
    hess_value = np.bmat([
        [elem*np.ones((n, n)) for elem in vec]
        for vec in np.eye(m)
    ]).astype("float32")
    self.assertAllEqual((m, n, m, n), hess_actual.shape)
    self.assertAllClose(hess_value, hess_actual.reshape((m * n, m * n)))


class IndexedSlicesToTensorTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensor(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testIndexedSlicesToTensorList(self):
    with self.cached_session():
      numpy_list = []
      dense_list = []
      sparse_list = []
      for _ in range(3):
        np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
        c = constant_op.constant(np_val)
        c_sparse = math_ops._as_indexed_slices(c)
        numpy_list.append(np_val)
        dense_list.append(c)
        sparse_list.append(c_sparse)
      packed_dense = array_ops.stack(dense_list)
      packed_sparse = array_ops.stack(sparse_list)
      self.assertAllClose(packed_dense.eval(), self.evaluate(packed_sparse))

  @test_util.run_v1_only("b/120545219")
  def testInt64Indices(self):
    with self.cached_session():
      np_val = np.random.rand(4, 4, 4, 4).astype(np.float32)
      c = constant_op.constant(np_val)
      c_sparse = math_ops._as_indexed_slices(c)
      c_sparse = ops.IndexedSlices(
          c_sparse.values,
          math_ops.cast(c_sparse.indices, dtypes.int64), c_sparse.dense_shape)
      self.assertAllEqual(np_val.shape, c_sparse.dense_shape.eval())
      c_dense = math_ops.multiply(c_sparse, 1.0)
      self.assertAllClose(np_val, self.evaluate(c_dense))

  @test_util.run_v1_only("b/120545219")
  def testWarnings(self):
    # TODO(gunan) Reenable after this issue is fixed:
    # https://github.com/google/protobuf/issues/2812
    if sys.version_info >= (3, 5):
      self.skipTest("Skipped test for Python 3.5+")

    # Smaller than the threshold: no warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([4, 4, 4, 4]))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(0, len(w))

    # Greater than or equal to the threshold: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32), constant([100, 100, 100, 100]))
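    # The dense shape has 100**4 = 100000000 elements, large enough to
    # trigger the memory warning checked below.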
    # "always" filter prevents the warning from being suppressed if it was
    # already triggered in a different test.
    warnings.simplefilter("always")
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "with 100000000 elements. This may consume a large amount of memory." in
        str(w[0].message))

    # Unknown dense shape: warning.
    c_sparse = ops.IndexedSlices(
        array_ops.placeholder(dtypes.float32),
        array_ops.placeholder(dtypes.int32),
        array_ops.placeholder(dtypes.int32))
    with warnings.catch_warnings(record=True) as w:
      math_ops.multiply(c_sparse, 1.0)
    self.assertEqual(1, len(w))
    self.assertTrue(
        "of unknown shape. This may consume a large amount of memory." in
        str(w[0].message))


class OnlyRealGradientsTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testRealOnly(self):
    x = constant_op.constant(7+3j, dtype=dtypes.complex64)
    y = math_ops.square(x)
    with self.assertRaisesRegexp(
        TypeError,
        r"Gradients of complex tensors must set grad_ys "
        r"\(y\.dtype = tf\.complex64\)"):
      gradients.gradients(y, x)


class ResourceCondTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def testBasic(self):
    gamma = resource_variable_ops.ResourceVariable(
        np.random.random((3,)),
        dtype="float32", name="gamma")

    inputs = array_ops.ones(shape=(3,), dtype="float32")

    def TestFn():
      output = inputs + gamma
      return output

    training = array_ops.placeholder_with_default(True, shape=())
    output = control_flow_ops.cond(
        training, TestFn, lambda: inputs)

    loss = output

    grads = gradients.gradients(
        loss, [gamma])
    self.assertNotIn(None, grads)


class GetDependentVariablesTest(test_util.TensorFlowTestCase):

  def testNoVariables(self):
    with ops.Graph().as_default():
      func = lambda x: array_ops.identity(x) + 5.0
      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])

      # There are no variables.
      self.assertEqual(dependent_vars, [])

  def testVariablesOutside(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + var

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])

  def testVariableSamePrefix(self):
    with ops.Graph().as_default():
      var_name = "my_variable"
      v_z = variable_scope.get_variable(var_name, shape=())
      v_o = variable_scope.get_variable(var_name + "_ones", shape=())

      # The variable is closed over. It should be found.
      func = lambda x: array_ops.identity(x) + 5.0 + v_z + v_o

      input_t = constant_op.constant(2.0)
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(set(dependent_vars), set([v_o, v_z]))

  def testVariablesOutsideButDSeparated(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0)
      var = variables.Variable(init)

      # The variable is d-separated by the inputs. It should not be found.
      input_t = array_ops.identity(var) * 5.0

      func = lambda x: array_ops.identity(x) + 5.0
      result_t = func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonDifferentiable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(x):
        # Non-differentiable dependency on var, so the variable should not
        # be found.
        y = array_ops.ones_like(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [])

  def testVariablesOutsideAndNonTrainable(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))

      # Both variables are used in the function but only the trainable one
      # should be found.
      var_trainable = variables.Variable(init, shape=(5,))
      var_nontrainable = variables.Variable(init, shape=(5,), trainable=False)

      def _Func(x):
        del x
        return var_trainable + var_nontrainable

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var_trainable])

  def testNesting(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      def _Func(inputs):
        x = inputs["x"]
        result = array_ops.identity(x) + 5.0 + var
        return {
            "y": result
        }

      input_t = constant_op.constant(2.0)
      func_inputs = {
          "x": input_t
      }
      result_t = _Func(func_inputs)

      # Ensure we can deal with dictionary input and output.
      dependent_vars = custom_gradient.get_dependent_variables(
          func_inputs, result_t)
      self.assertEqual(dependent_vars, [var])

  def testVariablesOutsideAndCustomGradient(self):
    with ops.Graph().as_default():
      init = constant_op.constant(100.0, shape=(5,))
      var = variables.Variable(init, shape=(5,))

      @custom_gradient.custom_gradient
      def _MyOnesLike(x):
        """Dummy version of ones_like which defines a gradient."""

        output = array_ops.ones_like(x)

        def _Grad(dy):
          return array_ops.identity(dy)

        return output, _Grad

      def _Func(x):
        # Non-differentiable operation with a custom gradient, so the
        # variable should be found.
        y = _MyOnesLike(var)
        return array_ops.identity(x) + 5.0 + y

      input_t = constant_op.constant(2.0)
      result_t = _Func(input_t)
      dependent_vars = custom_gradient.get_dependent_variables(
          [input_t], [result_t])
      self.assertEqual(dependent_vars, [var])


class CustomGradientTest(test_util.TensorFlowTestCase):

  def testCustomGradientTrivial(self):

    @custom_gradient.custom_gradient
    def MyIdentity(x):

      def Grad(dy):
        return [3 * dy]

      return x, Grad

    with ops.Graph().as_default():
      x = constant(3.)
      y = MyIdentity(MyIdentity(x))
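      # Each MyIdentity scales the incoming gradient by 3, so dy/dx == 9.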
      dy = gradients.gradients(y, x)[0]
      with session.Session():
        self.assertEqual(9., self.evaluate(dy))

  def testCustomGradient(self):

    @custom_gradient.custom_gradient
    def MyMultiply(x1, x2):
      result = x1 * x2

      def Grad(dy):
        # The true gradient would be [dy * x2, dy * x1]; the switched ordering
        # lets the test verify that this custom gradient function is used.
        return [dy * x1, dy * x2]

      return result, Grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      y = MyMultiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      with session.Session() as sess:
        self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientClass(self):

    class Model(object):

      @custom_gradient.custom_gradient
      def Multiply(self, x1, x2):
        result = x1 * x2
        grad = lambda dy: (dy * x1, dy * x2)
        return result, grad

    with ops.Graph().as_default():
      x1 = constant(3.)
      x2 = constant(5.)
      m = Model()
      y = m.Multiply(x1, x2)
      dy = gradients.gradients(y, [x1, x2])
      self.assertAllEqual([3., 5.], self.evaluate(dy))

  def testCustomGradientErrors(self):

    @custom_gradient.custom_gradient
    def F(x):

      def Grad(_):
        raise RuntimeError("x")

      return x, Grad

    with ops.Graph().as_default():
      x = constant(1.0)
      y = F(x)
      with self.assertRaises(RuntimeError):
        gradients.gradients(y, x)

  def testCustomGradientWithVariables(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = F(x)
        all_vars = vs.global_variables()
        self.assertEqual(len(all_vars), 1)
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
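        # Grad returns an all-ones (4, 3) tensor for the dense kernel, so the
        # summed gradient is 12.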
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)

  def testCustomGradientWithVariablesNoFalsePositives(self):

    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((3, 3))]

      return out, Grad

    with ops.Graph().as_default():
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        a = array_ops.ones((2, 4))

        # Variables in these layers shouldn't be picked up by the decorator.
        b = core_layers.dense(a, 3, use_bias=False)
        c = core_layers.dense(b, 3, use_bias=False)
        x = core_layers.dense(b, 3, use_bias=False) + c

        # Only the variables used in F.
        y = F(x)

        all_vars = vs.global_variables()
        self.assertEqual(len(all_vars), 4)
      grads = gradients.gradients(y, [x] + all_vars)
      _, var_grads = grads[0], grads[1:]
      for g in grads:
        self.assertIsNotNone(g)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(var_grads[-1]))
        self.assertEqual(9., dw)

  def testCustomGradientWithVariablesEager(self):
    with context.eager_mode():
      layer = core_layers.Dense(4, use_bias=False)

      @custom_gradient.custom_gradient
      def F(x):
        out = layer(x)

        def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
          del out_grad
          self.assertEqual(1, len(variables))
          return (array_ops.ones((3, 2)),
                  [array_ops.ones((2, 4))])

        return out, Grad

      x = array_ops.ones((3, 2)) + 2.
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = F(x)
      w, = layer.variables
      dx, dw = tape.gradient(y, [x, w])
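      # Grad returned all-ones tensors: sum(ones((3, 2))) == 6 and
      # sum(ones((2, 4))) == 8.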
      self.assertEqual(6., math_ops.reduce_sum(dx).numpy())
      self.assertEqual(8., math_ops.reduce_sum(dw).numpy())

  @test_util.run_v1_only("b/120545219")
  def testCustomGradientErrorsWithNonResourceVariables(self):

    def F(x, use_resource=False):
      with variable_scope.variable_scope("f", use_resource=use_resource):
        out = core_layers.dense(x, 4, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        del out_grad
        self.assertEqual(1, len(variables))
        return (array_ops.ones((3, 2)), [array_ops.ones((2, 4))])

      return out, Grad

    @custom_gradient.custom_gradient
    def FResource(x):
      return F(x, use_resource=True)

    @custom_gradient.custom_gradient
    def FNonResource(x):
      return F(x, use_resource=False)

    x = array_ops.ones((3, 2)) + 2.

    # Wrapping scope has use_resource=True but inner scope sets to False. Fails.
    with variable_scope.variable_scope("vs1", use_resource=True):
      with self.assertRaisesWithPredicateMatch(TypeError,
                                               "must be `ResourceVariable`s"):
        FNonResource(x)

    # Wrapping scope has use_resource=False but inner scope sets to True.
    # Passes.
    with variable_scope.variable_scope("vs2", use_resource=False):
      FResource(x)

  def testWithNumpyInputs(self):
    with context.eager_mode():

      @custom_gradient.custom_gradient
      def F(x):
        out = x

        def Grad(_):
          return (None, None)

        return out, Grad

      x = np.ones((3, 2), dtype=np.float32)
      # Smoke test to ensure numpy inputs are accepted
      F(x)

  @test_util.run_v1_only("b/120545219")
  def testRVGradientsDynamicCond(self):
    with self.cached_session():
      alpha = resource_variable_ops.ResourceVariable(
          np.random.random((1,)),
          dtype="float32")

      conditional = array_ops.placeholder_with_default(True, shape=())
      output = control_flow_ops.cond(
          conditional, lambda: alpha * 2, lambda: alpha * 3)

      g, = gradients_impl.gradients(output, alpha)
      self.evaluate(variables.global_variables_initializer())
      self.assertAllEqual(g.eval(), [2.0])
      self.assertAllEqual(g.eval(feed_dict={conditional: False}), [3.0])

  def testRecursiveCustomGradient(self):
    @custom_gradient.custom_gradient
    def F(x):
      out = core_layers.dense(x, 3, use_bias=False)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad

    @custom_gradient.custom_gradient
    def DoubleF(x):
      out = F(x)

      def Grad(out_grad, variables=None):  # pylint: disable=redefined-outer-name
        self.assertEqual(1, len(variables))
        grads = gradients.gradients(out, [x, variables[0]], grad_ys=out_grad)
        return grads[0], [array_ops.ones((4, 3))]

      return out, Grad
    with ops.Graph().as_default():
      x = array_ops.ones((2, 4))
      with variable_scope.variable_scope("f", use_resource=True) as vs:
        y = DoubleF(x)
        all_vars = vs.global_variables()
        self.assertEqual(len(all_vars), 1)
      grads = gradients.gradients(y, [x, all_vars[0]])
      for g in grads:
        self.assertIsNotNone(g)
      with session.Session() as sess:
        self.evaluate(variables.global_variables_initializer())
        dw = sess.run(math_ops.reduce_sum(grads[1]))
        self.assertEqual(12., dw)


class TensorListGradientsTest(test_util.TensorFlowTestCase):

  def testDefaultGradYs(self):
    with ops.Graph().as_default():
      tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      a = constant(1.0)
      tl = list_ops.tensor_list_push_back(tl, a)

      grad_tl = list_ops.empty_tensor_list(
          element_dtype=dtypes.float32,
          element_shape=ops.convert_to_tensor([], dtype=dtypes.int32))
      grad_tl = list_ops.tensor_list_push_back(grad_tl, constant(5.0))
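      # The gradient of pushing `a` onto the list is popped from the tail of
      # grad_tl, i.e. 5.0.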

      grad = gradients.gradients(tl, a, grad_ys=grad_tl)[0]
      with self.cached_session() as sess:
        self.assertEqual(self.evaluate(grad), 5.)


class TestKerasModelClass(training.Model):
  """A simple tensorflow keras Model class definition."""

  def __init__(self, width):
    super(TestKerasModelClass, self).__init__()

    self.weight = variable_scope.get_variable(
        name="test_keras_var",
        shape=width,
        dtype=dtypes.float32,
        trainable=True,
        use_resource=True,
    )

  def call(self, inputs):
    return self.weight * inputs


class VariablesGradientTest(test_util.TensorFlowTestCase):

  def _TestVariablesGradient(self, inputs, test_model, vars_to_grad):
    """Returns gradients of `test_model` with respect to `vars_to_grad`."""

    test_model_re = custom_gradient.recompute_grad(test_model)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_model_re(inputs)
      out = test_model(inputs)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
    """Returns gradients of `test_model` with respect to `vars_to_grad`."""

    test_fn_re = custom_gradient.recompute_grad(test_fn)

    with backprop.GradientTape(persistent=True) as tape:
      tape.watch(vars_to_grad)
      out_re = test_fn_re(inputs, vars_to_grad)
      out = test_fn(inputs, vars_to_grad)

    grads_re = tape.gradient(out_re, vars_to_grad)
    grads = tape.gradient(out, vars_to_grad)

    return grads_re, grads

  @test_util.run_in_graph_and_eager_modes
  def testKerasRecompute(self):
    """Checks that recompute_grad works for a simple Keras Model."""

    test_model = TestKerasModelClass(10)
    test_input = constant(np.zeros((10, 10), dtype=np.float32))
    self.evaluate(variables.global_variables_initializer())
    test_model(test_input)  # Ensures keras model is initialized.
    grads_re, grads = self._TestVariablesGradient(test_input, test_model,
                                                  test_input)

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

    grads_re, grads = self._TestVariablesGradient(test_input, test_model,
                                                  test_model.variables)

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecompute(self):
    """Checks that recompute_grad works grads of function args."""

    def TestFn(inputs, input_vars):
      return inputs * input_vars

    def TestFnSeq(inputs, input_vars):
      return (inputs * input_vars, inputs * input_vars * 2.0)

    with variable_scope.variable_scope("test", use_resource=True):
      test_var = variable_scope.get_variable(
          name="test_var",
          shape=10,
          trainable=True,
      )

      test_input = constant(np.zeros((10, 10), dtype=np.float32))

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_input)

      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFn,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      # Regression test for wrapping sequence outputting functions.
      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_input)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(test_input, TestFnSeq,
                                                      test_var)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

  @test_util.deprecated_graph_mode_only
  def testFnRecomputeWithScopeGradientTape(self):
    """Checks that recompute_grad works with var scope and GradientTape."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)

      with backprop.GradientTape(persistent=True) as tape:
        out_re = test_fn_re(test_input_t)
        out = TestFn(test_input_t)

    grads_re = tape.gradient(out_re, variables.trainable_variables())
    grads = tape.gradient(out, variables.trainable_variables())

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.deprecated_graph_mode_only
  def testFnRecomputeWithScopeGradients(self):
    """Checks that recompute_grad works with var scope and gradients(..)."""

    def TestFn(input_t):
      with variable_scope.variable_scope("inner_scope"):
        test_var = variable_scope.get_variable(
            name="test_var",
            shape=10,
            trainable=True,
        )
        return input_t * test_var

    test_input_t = constant(np.zeros((10, 10), dtype=np.float32))

    with variable_scope.variable_scope(
        "output_scope", reuse=variable_scope.AUTO_REUSE, use_resource=True):
      test_fn_re = custom_gradient.recompute_grad(TestFn)
      out_re = test_fn_re(test_input_t)
      out = TestFn(test_input_t)

    grads_re = gradients.gradients(out_re, variables.trainable_variables())
    grads = gradients.gradients(out, variables.trainable_variables())

    grads_re = self.evaluate(grads_re)
    grads = self.evaluate(grads)
    for g, g_re in zip(grads, grads_re):
      self.assertAllClose(g, g_re)

  @test_util.run_in_graph_and_eager_modes
  def testFnRecomputeSameTensor(self):
    """Check recompute_grad when wrapped f called as f(x, x) - b/147369366."""

    def TestFnMul(x, y):
      return x * y

    def TestFnSingleVar(x, y):  # pylint: disable=unused-argument
      return x

    with variable_scope.variable_scope("test", use_resource=True):
      x = array_ops.ones((10))

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnMul, x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)

      grads_re, grads = self._TestFnVariablesGradient(x, TestFnSingleVar, x)
      grads_re = self.evaluate(grads_re)
      grads = self.evaluate(grads)
      for g, g_re in zip(grads, grads_re):
        self.assertAllClose(g, g_re)


class GradPassThroughTest(test_util.TensorFlowTestCase):

  @test_util.run_v1_only("b/120545219")
  def test_gradients_v1(self):
    x = variable_scope.get_variable(
        name="x", shape=(), initializer=init_ops.constant_initializer(1.0),
        use_resource=True)
    z = variable_scope.get_variable(
        name="z", shape=(), initializer=init_ops.constant_initializer(3.0),
        use_resource=True)

    # Verify that assign op is not differentiable
    y = state_ops.assign(x, z**2)
    grads = gradients.gradients(y, z)
    self.assertIsNone(grads[0])

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as quadratic function of variable z and check that the
    # gradient of output wrt to z is correct.
    y = custom_gradient.grad_pass_through(
        lambda v: state_ops.assign(x, v))(z**2)
    grads = gradients.gradients(y, z)
    with self.cached_session() as sess:
      sess.run(variables.global_variables_initializer())
      self.assertAllClose(grads[0].eval(), 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = gradients.gradients(y, x)
    self.assertIsNone(grads[0])

  @test_util.run_v2_only
  def test_gradients_v2(self):
    x = variables.Variable(1.0, name="x")
    z = variables.Variable(3.0, name="z")

    # Verify that assign op is not differentiable
    with backprop.GradientTape() as tape:
      y = x.assign(z**2)
    grads = tape.gradient(y, z)
    self.assertIsNone(grads)

    # Verify that when the (non differentiable) assign op is wrapped with
    # grad_pass_through, gradients are correctly forwarded to the inputs.
    # Form an input as quadratic function of variable z and check that the
    # gradient of output wrt to z is correct.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(x.assign)(z**2)
    grads = tape.gradient(y, z)
    self.assertAllClose(grads, 6.0)

    # Verify that variables involved in the wrapped op do not receive gradients.
    with backprop.GradientTape() as tape:
      y = custom_gradient.grad_pass_through(lambda v: x * v)(z)
    grads = tape.gradient(y, x)
    self.assertIsNone(grads)


if __name__ == "__main__":
  googletest.main()