# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for compute_gradient."""
16
17import numpy as np
18
19from tensorflow.python.eager import backprop
20from tensorflow.python.framework import constant_op
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import sparse_tensor
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import custom_gradient
26from tensorflow.python.ops import \
27gradient_checker_v2 as gradient_checker
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import nn_ops
30from tensorflow.python.ops import sparse_ops
31# needs this to register gradient for SoftmaxCrossEntropyWithLogits:
32import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
33from tensorflow.python.platform import test
34from tensorflow.python.platform import tf_logging
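
# Each test below follows the same pattern: gradient_checker.compute_gradient
# evaluates both the analytical (backprop) Jacobian and the numerical
# (finite-difference) Jacobian of a function with respect to its inputs, and
# gradient_checker.max_error reduces the two sets of Jacobians to the maximum
# elementwise absolute difference, which should be close to zero when the
# registered gradients are correct.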


def _random_complex(shape, dtype):
  """Returns random data of `dtype`; fills the imaginary part if complex."""
  data = np.random.random_sample(shape).astype(dtype.as_numpy_dtype)
  if dtype.is_complex:
    data.imag = np.random.random_sample(shape)
  return data


@test_util.run_all_in_graph_and_eager_modes
class GradientCheckerTest(test.TestCase):

  def testSparseTensorReshape(self):
    x = constant_op.constant(2.0, shape=(2,))

    def sparse_tensor_reshape(values):
      sparse = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4])
      sparse = sparse_ops.sparse_reshape(sparse, shape=(12,))
      return sparse.values

    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(sparse_tensor_reshape, [x]))

    self.assertLess(error, 1e-4)

  def testWithStaticShape(self):
    size = (2, 3)
    constant = constant_op.constant(2.0, shape=size, name="const")

    def add_constant_with_static_shape_check(x):
      self.assertAllEqual(x.shape.as_list(), constant.shape.as_list())
      return x + constant

    x = constant_op.constant(3.0, shape=size, name="x")

    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        add_constant_with_static_shape_check, [x]))

    self.assertLess(error, 1e-4)

  def testWithArgumentsAsTuple(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, name="x1")
    x2 = constant_op.constant(3.0, shape=size, name="x2")

    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), (x1,)))

    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 1e-4)

  def testAddSimple(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, name="x1")
    x2 = constant_op.constant(3.0, shape=size, name="x2")
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), [x1]))
    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 1e-4)

  def testBfloat16(self):
    x1 = constant_op.constant(2.0, dtype="bfloat16")
    x2 = constant_op.constant(3.0, dtype="bfloat16")
    # bfloat16 is very imprecise, so we use a large delta and error bound here.
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), [x1], delta=0.1))
    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 0.07)

  def testAddCustomized(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, dtype=dtypes.float64, name="x1")
    x2 = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
    # Check gradients for x2 using a special delta.
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x2: math_ops.add(x1, x2), [x2], delta=1e-2))
    tf_logging.info("x2 error = %f", error)
    self.assertLess(error, 1e-10)

  def testGather(self):

    def f(params):
      index_values = [1, 3]
      indices = constant_op.constant(index_values, name="i")
      return array_ops.gather(params, indices, name="y")

    p_shape = (4, 2)
    p_size = 8
    params = constant_op.constant(
        np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [params]))
    tf_logging.info("gather error = %f", error)
    self.assertLess(error, 1e-4)

  def testNestedGather(self):

    def f(params):
      index_values = [1, 3, 5, 6]
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")
      index_values2 = [0, 2]
      indices2 = constant_op.constant(index_values2, name="i2")
      return array_ops.gather(y, indices2, name="y2")

    p_shape = (8, 2)
    p_size = 16
    params = constant_op.constant(
        np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [params]))
    tf_logging.info("nested gather error = %f", error)
    self.assertLess(error, 1e-4)

  def testComplexMul(self):
    c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)

    def f(x):
      return c * x

    x_shape = c.shape
    x_dtype = c.dtype
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    analytical, numerical = gradient_checker.compute_gradient(f, [x])
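    # compute_gradient treats each complex value as a (real, imag) pair, so
    # multiplying by c = 5 + 7j corresponds to the real 2x2 Jacobian
    # [[Re(c), -Im(c)], [Im(c), Re(c)]] = [[5, -7], [7, 5]].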
    correct = np.array([[5, -7], [7, 5]])
    self.assertAllEqual(correct, analytical[0])
    self.assertAllClose(correct, numerical[0], rtol=1e-4)
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    self.assertLess(
        gradient_checker.max_error(*gradient_checker.compute_gradient(f, [x])),
        3e-4)

  def testComplexConj(self):

    def f(x):
      return math_ops.conj(x)

    x_shape = ()
    x_dtype = dtypes.complex64
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    analytical, numerical = gradient_checker.compute_gradient(f, [x])
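    # Conjugation maps (real, imag) to (real, -imag), so in the same 2x2
    # representation the expected Jacobian is [[1, 0], [0, -1]].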
    correct = np.array([[1, 0], [0, -1]])
    self.assertAllEqual(correct, analytical[0])
    self.assertAllClose(correct, numerical[0], rtol=2e-5)
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    self.assertLess(
        gradient_checker.max_error(*gradient_checker.compute_gradient(f, [x])),
        2e-5)

  def testEmptySucceeds(self):

    def f(x):
      return array_ops.identity(x)

    x = constant_op.constant(
        np.random.random_sample((0, 3)), dtype=dtypes.float32)
    for grad in gradient_checker.compute_gradient(f, [x]):
      self.assertEqual(grad[0].shape, (0, 0))
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [x]))
    self.assertEqual(error, 0)

  def testEmptyMatMul(self):

    def f(x, y):
      return math_ops.matmul(x, y)

    x = constant_op.constant(
        np.random.random_sample((0, 3)), dtype=dtypes.float32)
    y = constant_op.constant(
        np.random.random_sample((3, 4)), dtype=dtypes.float32)
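    # Each Jacobian has shape (output size, input size); the product here has
    # 0 elements, so we expect shapes (0, 0) for x and (0, 12) for y.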
    for grad in gradient_checker.compute_gradient(f, [x, y]):
      self.assertEqual(grad[0].shape, (0, 0))
      self.assertEqual(grad[1].shape, (0, 12))
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [x, y]))
    self.assertEqual(error, 0)

  def testEmptyFails(self):

    @custom_gradient.custom_gradient
    def id_bad_grad(x):
      y = array_ops.identity(x)

      def grad_fn(dy):
        # dx = constant_op.constant(np.zeros((1, 4)), dtype=dtypes.float32)
        dx = array_ops.transpose(dy)
        return dx

      return y, grad_fn

    def f(x):
      return id_bad_grad(x)

    x = constant_op.constant(
        np.random.random_sample((0, 3)), dtype=dtypes.float32)
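    # The bad gradient above returns the transpose of dy, so its shape is
    # (3, 0) instead of the expected input shape (0, 3), and the checker
    # should reject it.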
    bad = r"Empty gradient has wrong shape: expected \(0, 3\), got \(3, 0\)"
    with self.assertRaisesRegex(ValueError, bad):
      gradient_checker.compute_gradient(f, [x])

  def testNaNGradFails(self):

    @custom_gradient.custom_gradient
    def id_nan_grad(x):
      y = array_ops.identity(x)

      def grad_fn(dy):
        dx = np.nan * dy
        # dx = dy
        return dx

      return y, grad_fn

    def f(x):
      return id_nan_grad(x)

    x = constant_op.constant(
        np.random.random_sample((1, 1)), dtype=dtypes.float32)
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [x]))
    # A typical test would assert error < max_err, so assert that such a check
    # would raise AssertionError here, since NaN is not < 1.0.
    with self.assertRaisesRegex(AssertionError, "nan not less than 1.0"):
      self.assertLess(error, 1.0)

  def testGradGrad(self):

    def f(x):
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = math_ops.square(x)
        z = math_ops.square(y)
      return tape.gradient(z, x)

    analytical, numerical = gradient_checker.compute_gradient(f, [2.0])
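    # f returns dz/dx = 4 * x**3 for z = x**4, so the Jacobian checked here is
    # d(4 * x**3)/dx = 12 * x**2, which is 48 at x = 2.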
    self.assertAllEqual([[[48.]]], analytical)
    self.assertAllClose([[[48.]]], numerical, rtol=1e-4)


@test_util.run_all_in_graph_and_eager_modes
class MiniMNISTTest(test.TestCase):

  # Gradient checker for MNIST.
  def _BuildAndTestMiniMNIST(self, param_index, tag):
    # Fix seed to avoid occasional flakiness
    np.random.seed(6)

    # Hyperparameters
    batch = 3
    inputs = 16
    features = 32
    classes = 10

    # Define the parameters
    inp_data = np.random.random_sample(inputs * batch)
    hidden_weight_data = np.random.randn(inputs * features) / np.sqrt(inputs)
    hidden_bias_data = np.random.random_sample(features)
    sm_weight_data = np.random.randn(features * classes) / np.sqrt(features)
    sm_bias_data = np.random.random_sample(classes)

    # special care for labels since they need to be normalized per batch
    label_data = np.random.random(batch * classes).reshape((batch, classes))
    s = label_data.sum(axis=1)
    label_data /= s[:, None]

    # We treat the inputs as "parameters" here
    inp = constant_op.constant(
        inp_data.tolist(),
        shape=[batch, inputs],
        dtype=dtypes.float64,
        name="inp")
    hidden_weight = constant_op.constant(
        hidden_weight_data.tolist(),
        shape=[inputs, features],
        dtype=dtypes.float64,
        name="hidden_weight")
    hidden_bias = constant_op.constant(
        hidden_bias_data.tolist(),
        shape=[features],
        dtype=dtypes.float64,
        name="hidden_bias")
    softmax_weight = constant_op.constant(
        sm_weight_data.tolist(),
        shape=[features, classes],
        dtype=dtypes.float64,
        name="softmax_weight")
    softmax_bias = constant_op.constant(
        sm_bias_data.tolist(),
        shape=[classes],
        dtype=dtypes.float64,
        name="softmax_bias")

    # List all the parameters so that we can test them one at a time.
    all_params = [inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias]

    # Now build the mini-MNIST network.
    def f(inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias):
      features = nn_ops.relu(
          nn_ops.xw_plus_b(inp, hidden_weight, hidden_bias), name="features")
      logits = nn_ops.xw_plus_b(
          features, softmax_weight, softmax_bias, name="logits")
      labels = constant_op.constant(
          label_data.tolist(),
          shape=[batch, classes],
          dtype=dtypes.float64,
          name="labels")
      cost = nn_ops.softmax_cross_entropy_with_logits(
          labels=labels, logits=logits, name="cost")
      return cost

    def f_restricted(x):
      xs = all_params
      i = param_index
      # use x for the i-th parameter
      xs = xs[0:i] + [x] + xs[i + 1:]
      return f(*xs)

    # Test the gradients.
    err = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f_restricted, [all_params[param_index]], delta=1e-5))

    tf_logging.info("Mini MNIST: %s gradient error = %g", tag, err)
    return err

  def testInputGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(0, "input"), 1e-8)

  def testHiddenWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(1, "hidden_weight"), 1e-8)

  def testHiddenBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(2, "hidden_bias"), 1e-8)

  def testSoftmaxWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(3, "softmax_weight"), 1e-8)

  def testSoftmaxBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(4, "softmax_bias"), 1e-8)


if __name__ == "__main__":
  test.main()