# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class BatchNormalizationTest(test.TestCase):

  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * math_ops.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    test_util.set_producer_version(ops.get_default_graph(), 8)
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return nn_impl.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return nn_impl.batch_normalization(x, m, v, beta if
                                       shift_after_normalization else None,
                                       gamma if scale_after_normalization else
                                       None, epsilon)

  @test_util.run_deprecated_v1
  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
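      # Each flag combination compares four implementations of the same
      # normalization: the NumPy reference, the composed TF ops, the fused V1
      # kernel (and its backward-compatible re-implementation), and the V2 API.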
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization,
                                      shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
                                          scale_after_normalization)
            bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization)
            on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                    scale_after_normalization,
                                    shift_after_normalization)
            np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
                                      epsilon, scale_after_normalization,
                                      shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)

  def _testBatchNormGradient(self,
                             param_index,
                             tag,
                             scale_after_normalization,
                             shift_after_normalization,
                             version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
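    # float64 inputs keep the numerical gradient accurate enough to meet the
    # tight default tolerance (err_tolerance=1e-11).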
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.cached_session():
      x = constant_op.constant(x_val, name="x")
      m = constant_op.constant(m_val, name="m")
      v = constant_op.constant(v_val, name="v")
      beta = constant_op.constant(beta_val, name="beta")
      gamma = constant_op.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
      else:
        raise ValueError("Invalid version: %s" % version)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
      err = gradient_checker.compute_gradient_error(all_params[param_index],
                                                    all_shapes[param_index],
                                                    output, x_shape)
      print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
            (version, tag, "with" if scale_after_normalization else "without",
             "with" if shift_after_normalization else "without"), err)
      self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(self,
                                             param_index,
                                             tag,
                                             err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(param_index, tag,
                                      scale_after_normalization,
                                      shift_after_normalization, v,
                                      err_tolerance)

  @test_util.run_deprecated_v1
  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  @test_util.run_deprecated_v1
  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  @test_util.run_deprecated_v1
  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(
        2, "variance", err_tolerance=1e-03)

  @test_util.run_deprecated_v1
  def testBatchNormBetaGradient(self):
    # Since beta does not exist when shift_after_normalization=False, we only
    # test for shift_after_normalization=True.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  @test_util.run_deprecated_v1
  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is
    # False, gamma is not used at all, and the gradient is None, which
    # displeases the gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)

  @test_util.run_deprecated_v1
  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
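    # Compares the fused batch-norm gradient op against gradients computed
    # through the composed ops from _opsBatchNorm.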
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = self.evaluate(
                [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  @test_util.run_deprecated_v1
  def testBatchNormKeepDims(self):
    """Test for tf.nn.moments(..., keep_dims=True / False).

    Make sure that parameters with shape (1, 1, 1, depth) yield the same
    result as parameters with shape (depth).
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        keep_dims_m = array_ops.reshape(
            m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = array_ops.reshape(
            v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = array_ops.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = array_ops.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
                                               keep_dims_beta, keep_dims_gamma,
                                               epsilon,
                                               scale_after_normalization,
                                               shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertEquals(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001,
                                    dtype=dtypes.float32,
                                    param_dtype=dtypes.float32):
    numpy_dtype = dtype.as_numpy_dtype
    numpy_param_dtype = param_dtype.as_numpy_dtype
    x_val = np.random.random_sample(x_shape).astype(numpy_dtype)
    m_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    v_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    beta_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    gamma_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
                                              gamma_val, epsilon,
                                              scale_after_normalization,
                                              shift_after_normalization)
            [tf_batch_norm] = self.evaluate([bn])
            self.assertEquals(x_shape, np_batch_norm.shape)
            self.assertEquals(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes(
        (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)

  def testBatchNormMixedPrecision(self):
    self._testBatchNormArbitraryShapes((3, 3), (1, 3), dtype=dtypes.float16,
                                       param_dtype=dtypes.float32, atol=0.001)


class SufficientStatisticsTest(test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims:
      shift = np.squeeze(shift, axis=axis)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = constant_op.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = self.evaluate([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = self.evaluate([op_c, op_m, op_v])
        else:
          x = array_ops.placeholder(
              dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
                                              feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
                                        feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  @test_util.run_deprecated_v1
  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


class NormalizeMomentsTest(test.TestCase):

  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
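    # normalize_moments turns sufficient statistics into moments:
    # mean = mean_ss / counts and variance = variance_ss / counts - mean**2
    # (see _npNormalizeMoments above).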
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        tf_counts = constant_op.constant(counts, name="counts")
        tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
        tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = constant_op.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = self.evaluate([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


class MomentsTest(test.TestCase):

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = array_ops.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
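      # expected_variance above uses the identity Var(x) = E[x^2] - (E[x])^2;
      # the ops are evaluated with a feed because x is a dynamically shaped
      # placeholder.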
      self.assertAllCloseAccordingToType(
          expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(
          expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, self.evaluate(mean))
      self.assertAllCloseAccordingToType(expected_variance, self.evaluate(var))

  @test_util.run_deprecated_v1
  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)

  @test_util.run_deprecated_v1
  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)

  @test_util.run_deprecated_v1
  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.cached_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = gradient_checker.compute_gradient_error(v,
                                                      v.get_shape().as_list(),
                                                      y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  @test_util.run_deprecated_v1
  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  @test_util.run_deprecated_v1
  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also continues to call MomentsTest.Run(Weighted)?MomentTest as
     well.)
  """

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = constant_op.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
                                                     dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(shape,
                                                                     axes,
                                                                     keep_dims,
                                                                     dtype)

    # 1:1 weights and inputs
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations
    for idx in range(len(shape)):
      # try broadcasting weights in all positions
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length n
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(self,
                            shape,
                            weights_shape,
                            axes,
                            keep_dims,
                            dtype,
                            dynshapes=False):
    with self.cached_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(
              size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = ([None] * len(weights_shape) if dynshapes else
                       weights_shape)

      x = array_ops.placeholder(dtype, shape=x_shape)
      weights = array_ops.placeholder(dtype, shape=weights_shape)

      mean, var = nn_impl.weighted_moments(
          x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
                            weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy,
                                       weights: weights_numpy})
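      # The weighted-moments op should agree with the NumPy weighted reference
      # computed above, within the tolerance for `dtype`.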
      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  test.main()