# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class BatchNormalizationTest(test.TestCase):

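  # Reference formula exercised throughout this class (a restatement of what
  # the helpers below compute, not a new definition):
  #   y = (x - m) / sqrt(v + epsilon)
  #   y *= gamma   if scale_after_normalization
  #   y += beta    if shift_after_normalization
  # _npBatchNorm evaluates it with NumPy and _opsBatchNorm with individual
  # TensorFlow ops (math_ops.rsqrt); both act as ground truth for the fused
  # batch-norm kernels under test.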
  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * math_ops.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    test_util.set_producer_version(ops.get_default_graph(), 8)
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return nn_impl.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return nn_impl.batch_normalization(
        x, m, v,
        beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None,
        epsilon)

  @test_util.run_deprecated_v1
  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization,
                                      shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
                                          scale_after_normalization)
            bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization)
            on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                    scale_after_normalization,
                                    shift_after_normalization)
            np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
                                      epsilon, scale_after_normalization,
                                      shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)

  def _testBatchNormGradient(self,
                             param_index,
                             tag,
                             scale_after_normalization,
                             shift_after_normalization,
                             version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.cached_session():
      x = constant_op.constant(x_val, name="x")
      m = constant_op.constant(m_val, name="m")
      v = constant_op.constant(v_val, name="v")
      beta = constant_op.constant(beta_val, name="beta")
      gamma = constant_op.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
      else:
        print("Invalid version", version)
        raise ValueError()
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
      err = gradient_checker.compute_gradient_error(all_params[param_index],
                                                    all_shapes[param_index],
                                                    output, x_shape)
    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
          (version, tag, "with" if scale_after_normalization else "without",
           "with" if shift_after_normalization else "without"), err)
    self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(self,
                                             param_index,
                                             tag,
                                             err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(param_index, tag,
                                      scale_after_normalization,
                                      shift_after_normalization, v,
                                      err_tolerance)

  @test_util.run_deprecated_v1
  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  @test_util.run_deprecated_v1
  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  @test_util.run_deprecated_v1
  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(
        2, "variance", err_tolerance=1e-03)

  @test_util.run_deprecated_v1
  def testBatchNormBetaGradient(self):
    # Since beta does not exist when shift_after_normalization=False, we only
    # test with shift_after_normalization=True; scale_after_normalization is
    # exercised both ways.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  @test_util.run_deprecated_v1
  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is False,
    # gamma is not used at all, and the gradient is None, which displeases the
    # gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)

  @test_util.run_deprecated_v1
  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = self.evaluate(
                [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  @test_util.run_deprecated_v1
  def testBatchNormKeepDims(self):
252    """Test for tf.nn.moments(..., keep_dims=True / False).

    Make sure that parameters with shape (1, 1, 1, depth) yield the same
    result as parameters with shape (depth)
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        keep_dims_m = array_ops.reshape(
            m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = array_ops.reshape(
            v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = array_ops.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = array_ops.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
                                               keep_dims_beta, keep_dims_gamma,
                                               epsilon,
                                               scale_after_normalization,
                                               shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertEqual(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001,
                                    dtype=dtypes.float32,
                                    param_dtype=dtypes.float32):
    numpy_dtype = dtype.as_numpy_dtype
    numpy_param_dtype = param_dtype.as_numpy_dtype
    x_val = np.random.random_sample(x_shape).astype(numpy_dtype)
    m_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    v_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    beta_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    gamma_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
                                              gamma_val, epsilon,
                                              scale_after_normalization,
                                              shift_after_normalization)
            [tf_batch_norm] = self.evaluate([bn])
            self.assertEqual(x_shape, np_batch_norm.shape)
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes(
        (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)

  def testBatchNormMixedPrecision(self):
    self._testBatchNormArbitraryShapes((3, 3), (1, 3), dtype=dtypes.float16,
                                       param_dtype=dtypes.float32, atol=0.001)


class SufficientStatisticsTest(test.TestCase):

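  # Reference math for the NumPy helper below (a sketch, not the op's
  # implementation): with an optional shift s, the sufficient statistics of x
  # over the reduced axes are
  #   count = number of elements reduced,
  #   m_ss  = sum(x - s),
  #   v_ss  = sum((x - s) ** 2),
  # which is what nn_impl.sufficient_statistics is expected to return.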
  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in range(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims:
      shift = np.asarray(shift)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = constant_op.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = self.evaluate([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = self.evaluate([op_c, op_m, op_v])
        else:
          x = array_ops.placeholder(
              dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
                                              feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
                                        feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  @test_util.run_deprecated_v1
  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


class NormalizeMomentsTest(test.TestCase):

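  # Reference math for these checks (a sketch, not the op's implementation):
  # normalize_moments recovers moments from sufficient statistics via
  #   mean     = shift + mean_ss / counts,
  #   variance = variance_ss / counts - (mean_ss / counts) ** 2,
  # i.e. the shifted-moments identity Var[x] = E[(x - s)^2] - (E[x - s])^2.
  # Note that the variance term uses the unshifted mean, which is why the
  # NumPy helper below adds the shift only after variance is computed.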
  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        tf_counts = constant_op.constant(counts, name="counts")
        tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
        tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = constant_op.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = self.evaluate([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


class MomentsTest(test.TestCase):

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = array_ops.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(
          expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(
          expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)
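      # Illustration of the cancellation (editorial example, not test data):
      # for x = [1e4, 1e4 + 1], E[x^2] and E[x]^2 are both ~1.0001e8 while
      # their difference is 0.25; float32 resolution near 1e8 is ~8, so the
      # true variance would be lost entirely without the higher precision.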

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, self.evaluate(mean))
      self.assertAllCloseAccordingToType(expected_variance, self.evaluate(var))

  @test_util.run_deprecated_v1
  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)

  @test_util.run_deprecated_v1
  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)

  @test_util.run_deprecated_v1
  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.cached_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = gradient_checker.compute_gradient_error(v,
                                                      v.get_shape().as_list(),
                                                      y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  @test_util.run_deprecated_v1
  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  @test_util.run_deprecated_v1
  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also still calls the inherited MomentsTest.RunMomentTest.)

  """

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = constant_op.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
                                                     dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(shape,
                                                                     axes,
                                                                     keep_dims,
                                                                     dtype)

    # 1:1 weights and inputs
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations
    for idx in range(len(shape)):
      # try broadcasting weights in all positions
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length n
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(self,
                            shape,
                            weights_shape,
                            axes,
                            keep_dims,
                            dtype,
                            dynshapes=False):
    with self.cached_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(
              size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = ([None] * len(weights_shape) if dynshapes else
                       weights_shape)

      x = array_ops.placeholder(dtype, shape=x_shape)
      weights = array_ops.placeholder(dtype, shape=weights_shape)

      mean, var = nn_impl.weighted_moments(
          x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
                            weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy,
                                       weights: weights_numpy})

      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  test.main()