# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for compute_gradient."""

import numpy as np

from tensorflow.python.eager import backprop
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import custom_gradient
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import sparse_ops
# needs this to register gradient for SoftmaxCrossEntropyWithLogits:
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging


def _random_complex(shape, dtype):
  data = np.random.random_sample(shape).astype(dtype.as_numpy_dtype)
  if dtype.is_complex:
    data.imag = np.random.random_sample(shape)
  return data


@test_util.run_all_in_graph_and_eager_modes
class GradientCheckerTest(test.TestCase):

  def testSparseTensorReshape(self):
    x = constant_op.constant(2.0, shape=(2,))

    def sparse_tensor_reshape(values):
      sparse = sparse_tensor.SparseTensor(
          indices=[[0, 0], [1, 2]], values=values, dense_shape=[3, 4])
      sparse = sparse_ops.sparse_reshape(sparse, shape=(12,))
      return sparse.values

    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(sparse_tensor_reshape, [x]))

    self.assertLess(error, 1e-4)

  def testWithStaticShape(self):
    size = (2, 3)
    constant = constant_op.constant(2.0, shape=size, name="const")

    def add_constant_with_static_shape_check(x):
      self.assertAllEqual(x.shape.as_list(), constant.shape.as_list())
      return x + constant

    x = constant_op.constant(3.0, shape=size, name="x")

    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        add_constant_with_static_shape_check, [x]))

    self.assertLess(error, 1e-4)

  def testWithArgumentsAsTuple(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, name="x1")
    x2 = constant_op.constant(3.0, shape=size, name="x2")

    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), (x1,)))

    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 1e-4)

  def testAddSimple(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, name="x1")
    x2 = constant_op.constant(3.0, shape=size, name="x2")
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), [x1]))
    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 1e-4)

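  # Note: bfloat16 keeps float32's 8 exponent bits but only 7 fraction bits,
  # so finite-difference estimates of the gradient are coarse; the test below
  # therefore uses a large step (delta=0.1) and a loose error bound.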
  def testBfloat16(self):
    x1 = constant_op.constant(2.0, dtype="bfloat16")
    x2 = constant_op.constant(3.0, dtype="bfloat16")
    # bfloat16 is very imprecise, so use a very large delta and error bound.
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x1: math_ops.add(x1, x2), [x1], delta=0.1))
    tf_logging.info("x1 error = %f", error)
    self.assertLess(error, 0.07)

  def testAddCustomized(self):
    size = (2, 3)
    x1 = constant_op.constant(2.0, shape=size, dtype=dtypes.float64, name="x1")
    x2 = np.asarray(np.arange(6, dtype=np.float64).reshape(2, 3))
    # Check gradients for x2 using a special delta.
    error = gradient_checker.max_error(*gradient_checker.compute_gradient(
        lambda x2: math_ops.add(x1, x2), [x2], delta=1e-2))
    tf_logging.info("x2 error = %f", error)
    self.assertLess(error, 1e-10)

  def testGather(self):

    def f(params):
      index_values = [1, 3]
      indices = constant_op.constant(index_values, name="i")
      return array_ops.gather(params, indices, name="y")

    p_shape = (4, 2)
    p_size = 8
    params = constant_op.constant(
        np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [params]))
    tf_logging.info("gather error = %f", error)
    self.assertLess(error, 1e-4)

  def testNestedGather(self):

    def f(params):
      index_values = [1, 3, 5, 6]
      indices = constant_op.constant(index_values, name="i")
      y = array_ops.gather(params, indices, name="y")
      index_values2 = [0, 2]
      indices2 = constant_op.constant(index_values2, name="i2")
      return array_ops.gather(y, indices2, name="y2")

    p_shape = (8, 2)
    p_size = 16
    params = constant_op.constant(
        np.arange(p_size).astype(np.float64), shape=p_shape, name="p")
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [params]))
    tf_logging.info("nested gather error = %f", error)
    self.assertLess(error, 1e-4)

  def testComplexMul(self):
    c = constant_op.constant(5 + 7j, dtype=dtypes.complex64)

    def f(x):
      return c * x

    x_shape = c.shape
    x_dtype = c.dtype
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    analytical, numerical = gradient_checker.compute_gradient(f, [x])
    correct = np.array([[5, -7], [7, 5]])
    self.assertAllEqual(correct, analytical[0])
    self.assertAllClose(correct, numerical[0], rtol=1e-4)
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    self.assertLess(
        gradient_checker.max_error(*gradient_checker.compute_gradient(f, [x])),
        3e-4)

  def testComplexConj(self):

    def f(x):
      return math_ops.conj(x)

    x_shape = ()
    x_dtype = dtypes.complex64
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    analytical, numerical = gradient_checker.compute_gradient(f, [x])
    correct = np.array([[1, 0], [0, -1]])
    self.assertAllEqual(correct, analytical[0])
    self.assertAllClose(correct, numerical[0], rtol=2e-5)
    x = constant_op.constant(_random_complex(x_shape, x_dtype))
    self.assertLess(
        gradient_checker.max_error(*gradient_checker.compute_gradient(f, [x])),
        2e-5)

  def testEmptySucceeds(self):

    def f(x):
      return array_ops.identity(x)

    x = constant_op.constant(
        np.random.random_sample((0, 3)), dtype=dtypes.float32)
    for grad in gradient_checker.compute_gradient(f, [x]):
      self.assertEqual(grad[0].shape, (0, 0))
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [x]))
    self.assertEqual(error, 0)

  def testEmptyMatMul(self):

    def f(x, y):
      return math_ops.matmul(x, y)

    x = constant_op.constant(
        np.random.random_sample((0, 3)), dtype=dtypes.float32)
    y = constant_op.constant(
        np.random.random_sample((3, 4)), dtype=dtypes.float32)
    for grad in gradient_checker.compute_gradient(f, [x, y]):
      self.assertEqual(grad[0].shape, (0, 0))
      self.assertEqual(grad[1].shape, (0, 12))
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [x, y]))
    self.assertEqual(error, 0)

  def testEmptyFails(self):

    @custom_gradient.custom_gradient
    def id_bad_grad(x):
      y = array_ops.identity(x)

      def grad_fn(dy):
        # dx = constant_op.constant(np.zeros((1, 4)), dtype=dtypes.float32)
        dx = array_ops.transpose(dy)
        return dx

      return y, grad_fn

    def f(x):
      return id_bad_grad(x)

    x = constant_op.constant(
        np.random.random_sample((0, 3)), dtype=dtypes.float32)
    bad = r"Empty gradient has wrong shape: expected \(0, 3\), got \(3, 0\)"
    with self.assertRaisesRegex(ValueError, bad):
      gradient_checker.compute_gradient(f, [x])

  def testNaNGradFails(self):

    @custom_gradient.custom_gradient
    def id_nan_grad(x):
      y = array_ops.identity(x)

      def grad_fn(dy):
        dx = np.nan * dy
        # dx = dy
        return dx

      return y, grad_fn

    def f(x):
      return id_nan_grad(x)

    x = constant_op.constant(
        np.random.random_sample((1, 1)), dtype=dtypes.float32)
    error = gradient_checker.max_error(
        *gradient_checker.compute_gradient(f, [x]))
    # A typical test would assert error < max_err, so assert that this test
    # would raise an AssertionError, since NaN is not < 1.0.
    with self.assertRaisesRegex(AssertionError, "nan not less than 1.0"):
      self.assertLess(error, 1.0)

  def testGradGrad(self):

    def f(x):
      with backprop.GradientTape() as tape:
        tape.watch(x)
        y = math_ops.square(x)
        z = math_ops.square(y)
      return tape.gradient(z, x)

    analytical, numerical = gradient_checker.compute_gradient(f, [2.0])
    self.assertAllEqual([[[48.]]], analytical)
    self.assertAllClose([[[48.]]], numerical, rtol=1e-4)


@test_util.run_all_in_graph_and_eager_modes
class MiniMNISTTest(test.TestCase):

  # Gradient checker for MNIST.
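  # Builds a small one-hidden-layer softmax network in float64, then checks
  # the gradient of the cross-entropy cost with respect to the parameter
  # selected by param_index, returning the maximum error reported by
  # gradient_checker.max_error.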
  def _BuildAndTestMiniMNIST(self, param_index, tag):
    # Fix the seed to avoid occasional flakiness.
    np.random.seed(6)

    # Hyperparameters.
    batch = 3
    inputs = 16
    features = 32
    classes = 10

    # Define the parameters.
    inp_data = np.random.random_sample(inputs * batch)
    hidden_weight_data = np.random.randn(inputs * features) / np.sqrt(inputs)
    hidden_bias_data = np.random.random_sample(features)
    sm_weight_data = np.random.randn(features * classes) / np.sqrt(features)
    sm_bias_data = np.random.random_sample(classes)

    # Special care for labels since they need to be normalized per batch.
    label_data = np.random.random(batch * classes).reshape((batch, classes))
    s = label_data.sum(axis=1)
    label_data /= s[:, None]

    # We treat the inputs as "parameters" here.
    inp = constant_op.constant(
        inp_data.tolist(),
        shape=[batch, inputs],
        dtype=dtypes.float64,
        name="inp")
    hidden_weight = constant_op.constant(
        hidden_weight_data.tolist(),
        shape=[inputs, features],
        dtype=dtypes.float64,
        name="hidden_weight")
    hidden_bias = constant_op.constant(
        hidden_bias_data.tolist(),
        shape=[features],
        dtype=dtypes.float64,
        name="hidden_bias")
    softmax_weight = constant_op.constant(
        sm_weight_data.tolist(),
        shape=[features, classes],
        dtype=dtypes.float64,
        name="softmax_weight")
    softmax_bias = constant_op.constant(
        sm_bias_data.tolist(),
        shape=[classes],
        dtype=dtypes.float64,
        name="softmax_bias")

    # List all the parameters so that we can test them one at a time.
    all_params = [inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias]

    # Now build the MNIST network.
    def f(inp, hidden_weight, hidden_bias, softmax_weight, softmax_bias):
      features = nn_ops.relu(
          nn_ops.xw_plus_b(inp, hidden_weight, hidden_bias), name="features")
      logits = nn_ops.xw_plus_b(
          features, softmax_weight, softmax_bias, name="logits")
      labels = constant_op.constant(
          label_data.tolist(),
          shape=[batch, classes],
          dtype=dtypes.float64,
          name="labels")
      cost = nn_ops.softmax_cross_entropy_with_logits(
          labels=labels, logits=logits, name="cost")
      return cost

    def f_restricted(x):
      xs = all_params
      i = param_index
      # Use x for the i-th parameter.
      xs = xs[0:i] + [x] + xs[i + 1:]
      return f(*xs)

    # Test the gradients.
    err = gradient_checker.max_error(*gradient_checker.compute_gradient(
        f_restricted, [all_params[param_index]], delta=1e-5))

    tf_logging.info("Mini MNIST: %s gradient error = %g", tag, err)
    return err

  def testInputGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(0, "input"), 1e-8)

  def testHiddenWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(1, "hidden_weight"), 1e-8)

  def testHiddenBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(2, "hidden_bias"), 1e-8)

  def testSoftmaxWeightGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(3, "softmax_weight"), 1e-8)

  def testSoftmaxBiasGradient(self):
    self.assertLess(self._BuildAndTestMiniMNIST(4, "softmax_bias"), 1e-8)


if __name__ == "__main__":
  test.main()