# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class BatchNormalizationTest(test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * math_ops.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    # Use an older producer version so the deprecated V1 op is still available.
    test_util.set_producer_version(ops.get_default_graph(), 8)
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return nn_impl.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return nn_impl.batch_normalization(
        x, m, v,
        beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None,
        epsilon)

  @test_util.run_deprecated_v1
  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization,
                                      shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
                                          scale_after_normalization)
            bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization)
            on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                    scale_after_normalization,
                                    shift_after_normalization)
            np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
                                      epsilon, scale_after_normalization,
                                      shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)

  def _testBatchNormGradient(self,
                             param_index,
                             tag,
                             scale_after_normalization,
                             shift_after_normalization,
                             version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.cached_session():
      x = constant_op.constant(x_val, name="x")
      m = constant_op.constant(m_val, name="m")
      v = constant_op.constant(v_val, name="v")
      beta = constant_op.constant(beta_val, name="beta")
      gamma = constant_op.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
      else:
        raise ValueError("Invalid version: %s" % version)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
      err = gradient_checker.compute_gradient_error(all_params[param_index],
                                                    all_shapes[param_index],
                                                    output, x_shape)
      print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
            (version, tag, "with" if scale_after_normalization else "without",
             "with" if shift_after_normalization else "without"), err)
      self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(self,
                                             param_index,
                                             tag,
                                             err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
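        # (The V1 kernel has no shift flag and always applies beta.)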
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(param_index, tag,
                                      scale_after_normalization,
                                      shift_after_normalization, v,
                                      err_tolerance)

  @test_util.run_deprecated_v1
  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  @test_util.run_deprecated_v1
  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  @test_util.run_deprecated_v1
  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(
        2, "variance", err_tolerance=1e-03)

  @test_util.run_deprecated_v1
  def testBatchNormBetaGradient(self):
    # Since beta does not exist when shift_after_normalization=False, we only
    # test with shift_after_normalization=True.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  @test_util.run_deprecated_v1
  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is
    # False, gamma is not used at all, and the gradient is None, which
    # displeases the gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)

  @test_util.run_deprecated_v1
  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = self.evaluate(
                [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  @test_util.run_deprecated_v1
  def testBatchNormKeepDims(self):
    """Test for tf.nn.moments(..., keep_dims=True / False).

    Make sure that parameters with shape (1, 1, 1, depth) yield the same
    result as parameters with shape (depth).
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2,)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        keep_dims_m = array_ops.reshape(
            m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = array_ops.reshape(
            v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = array_ops.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = array_ops.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
                                               keep_dims_beta, keep_dims_gamma,
                                               epsilon,
                                               scale_after_normalization,
                                               shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertEqual(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001,
                                    dtype=dtypes.float32,
                                    param_dtype=dtypes.float32):
    numpy_dtype = dtype.as_numpy_dtype
    numpy_param_dtype = param_dtype.as_numpy_dtype
    x_val = np.random.random_sample(x_shape).astype(numpy_dtype)
    m_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    v_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    beta_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    gamma_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
                                              gamma_val, epsilon,
                                              scale_after_normalization,
                                              shift_after_normalization)
            [tf_batch_norm] = self.evaluate([bn])
            self.assertEqual(x_shape, np_batch_norm.shape)
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes(
        (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)

  def testBatchNormMixedPrecision(self):
    self._testBatchNormArbitraryShapes((3, 3), (1, 3), dtype=dtypes.float16,
                                       param_dtype=dtypes.float32, atol=0.001)


class SufficientStatisticsTest(test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in range(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims:
      shift = np.asarray(shift)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = constant_op.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = self.evaluate([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = self.evaluate([op_c, op_m, op_v])
        else:
          x = array_ops.placeholder(
              dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
                                              feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
                                        feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  @test_util.run_deprecated_v1
  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


class NormalizeMomentsTest(test.TestCase):

  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        tf_counts = constant_op.constant(counts, name="counts")
        tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
        tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = constant_op.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = self.evaluate([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


class MomentsTest(test.TestCase):

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = array_ops.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
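      # The input is a placeholder with fully dynamic shape, so the numpy
      # values are fed in at evaluation time.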
      self.assertAllCloseAccordingToType(
          expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(
          expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, self.evaluate(mean))
      self.assertAllCloseAccordingToType(expected_variance, self.evaluate(var))

  @test_util.run_deprecated_v1
  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)

  @test_util.run_deprecated_v1
  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)

  @test_util.run_deprecated_v1
  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.cached_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = gradient_checker.compute_gradient_error(v,
                                                      v.get_shape().as_list(),
                                                      y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  @test_util.run_deprecated_v1
  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  @test_util.run_deprecated_v1
  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also continues to call the inherited MomentsTest versions as
     well.)
  """

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = constant_op.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
                                                     dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(shape,
                                                                     axes,
                                                                     keep_dims,
                                                                     dtype)

    # 1:1 weights and inputs
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations
    for idx in range(len(shape)):
      # try broadcasting weights in all positions
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length n
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(self,
                            shape,
                            weights_shape,
                            axes,
                            keep_dims,
                            dtype,
                            dynshapes=False):
    with self.cached_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(
              size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = ([None] * len(weights_shape) if dynshapes else
                       weights_shape)

      x = array_ops.placeholder(dtype, shape=x_shape)
      weights = array_ops.placeholder(dtype, shape=weights_shape)

      mean, var = nn_impl.weighted_moments(
          x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
                            weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy,
                                       weights: weights_numpy})

      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  test.main()