# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for batch_norm related functionality in tensorflow.ops.nn."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from six.moves import xrange  # pylint: disable=redefined-builtin

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class BatchNormalizationTest(test.TestCase):

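  # Reference implementations of batch normalization used throughout:
  #   y = (x - m) / sqrt(v + epsilon), optionally scaled by gamma and
  #   shifted by beta, computed in NumPy (_npBatchNorm) and with composed
  #   TF ops (_opsBatchNorm).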
  def _npBatchNorm(self, x, m, v, beta, gamma, epsilon,
                   scale_after_normalization, shift_after_normalization):
    y = (x - m) / np.sqrt(v + epsilon)
    y = y * gamma if scale_after_normalization else y
    return y + beta if shift_after_normalization else y

  def _opsBatchNorm(self, x, m, v, beta, gamma, epsilon,
                    scale_after_normalization, shift_after_normalization):
    y = (x - m) * math_ops.rsqrt(v + epsilon)
    if scale_after_normalization:
      y = gamma * y
    return y + beta if shift_after_normalization else y

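  # The V1 kernel below pins the default graph's GraphDef producer version to
  # 8 so the deprecated op can still be constructed (the matching gradient op
  # is noted as deprecated in v9 in testBatchNormGradImpl).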
  def _tfBatchNormV1(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization):
    """Original implementation."""
    test_util.set_producer_version(ops.get_default_graph(), 8)
    return gen_nn_ops._batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV1BW(self, x, m, v, beta, gamma, epsilon,
                       scale_after_normalization):
    """Re-implementation of the original kernel for backward compatibility."""
    return nn_impl.batch_norm_with_global_normalization(
        x, m, v, beta, gamma, epsilon, scale_after_normalization)

  def _tfBatchNormV2(self, x, m, v, beta, gamma, epsilon,
                     scale_after_normalization, shift_after_normalization):
    """New implementation."""
    return nn_impl.batch_normalization(
        x, m, v,
        beta if shift_after_normalization else None,
        gamma if scale_after_normalization else None,
        epsilon)

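  # Cross-checks the NumPy reference against the composed-ops version, the V1
  # kernel, its backward-compatible re-implementation, and the V2 API for
  # every scale/shift combination, with and without use_gpu.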
  @test_util.run_deprecated_v1
  def testBatchNorm(self):
    x_shape = [3, 5, 4, 2]
    param_shape = [2]
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn2 = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization,
                                      shift_after_normalization)
            bn1bw = self._tfBatchNormV1BW(x, m, v, beta, gamma, epsilon,
                                          scale_after_normalization)
            bn1 = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                      scale_after_normalization)
            on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                    scale_after_normalization,
                                    shift_after_normalization)
            np_bn = self._npBatchNorm(x_val, m_val, v_val, beta_val, gamma_val,
                                      epsilon, scale_after_normalization,
                                      shift_after_normalization)
            tf_bn_v2, tf_bn_v1bw, tf_bn_v1, ops_bn = sess.run(
                [bn2, bn1bw, bn1, on])
            self.assertAllClose(np_bn, ops_bn, atol=0.00001)
            self.assertAllClose(np_bn, tf_bn_v2, atol=0.00001)
            self.assertAllClose(tf_bn_v2, ops_bn, atol=0.00001)
            # shift_after_normalization=False is not supported in v1.
            if shift_after_normalization:
              self.assertAllClose(np_bn, tf_bn_v1bw, atol=0.00001)
              self.assertAllClose(np_bn, tf_bn_v1, atol=0.00001)
              self.assertAllClose(tf_bn_v1, ops_bn, atol=0.00001)
              self.assertAllClose(tf_bn_v1bw, ops_bn, atol=0.00001)

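  # Checks the analytic gradient of the selected parameter (x, m, v, beta, or
  # gamma) against a numeric estimate via
  # gradient_checker.compute_gradient_error.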
  def _testBatchNormGradient(self,
                             param_index,
                             tag,
                             scale_after_normalization,
                             shift_after_normalization,
                             version,
                             err_tolerance=1e-11):
    x_shape = [3, 5, 4, 5]
    param_shape = [5]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    m_val = np.random.random_sample(param_shape).astype(np.float64)
    v_val = np.random.random_sample(param_shape).astype(np.float64)
    beta_val = np.random.random_sample(param_shape).astype(np.float64)
    gamma_val = np.random.random_sample(param_shape).astype(np.float64)
    with self.cached_session():
      x = constant_op.constant(x_val, name="x")
      m = constant_op.constant(m_val, name="m")
      v = constant_op.constant(v_val, name="v")
      beta = constant_op.constant(beta_val, name="beta")
      gamma = constant_op.constant(gamma_val, name="gamma")
      epsilon = 0.001
      if version == 1:
        output = self._tfBatchNormV1(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization)
      elif version == 2:
        output = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
      else:
        raise ValueError("Invalid version: %s" % version)
      all_params = [x, m, v, beta, gamma]
      all_shapes = [x_shape, param_shape, param_shape, param_shape, param_shape]
      err = gradient_checker.compute_gradient_error(all_params[param_index],
                                                    all_shapes[param_index],
                                                    output, x_shape)
    print("Batch normalization v%d %s gradient %s scale and %s shift err = " %
          (version, tag, "with" if scale_after_normalization else "without",
           "with" if shift_after_normalization else "without"), err)
    self.assertLess(err, err_tolerance)

  def _testBatchNormGradientInAllNeedConfigs(self,
                                             param_index,
                                             tag,
                                             err_tolerance=1e-11):
    for scale_after_normalization in [True, False]:
      for shift_after_normalization in [True, False]:
        # shift_after_normalization=False is not supported in version 1.
        for v in ([1, 2] if shift_after_normalization else [2]):
          self._testBatchNormGradient(param_index, tag,
                                      scale_after_normalization,
                                      shift_after_normalization, v,
                                      err_tolerance)

  @test_util.run_deprecated_v1
  def testBatchNormInputGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(0, "x")

  @test_util.run_deprecated_v1
  def testBatchNormMeanGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(1, "mean")

  @test_util.run_deprecated_v1
  def testBatchNormVarianceGradient(self):
    self._testBatchNormGradientInAllNeedConfigs(
        2, "variance", err_tolerance=1e-03)

  @test_util.run_deprecated_v1
  def testBatchNormBetaGradient(self):
    # beta only affects the output when shift_after_normalization=True, so
    # that flag is fixed to True here; both scale settings are exercised.
    for scale_after_normalization in [True, False]:
      for v in [1, 2]:
        self._testBatchNormGradient(3, "beta", scale_after_normalization, True,
                                    v)

  @test_util.run_deprecated_v1
  def testBatchNormGammaGradient(self):
    # If scale_after_normalization is False, backprop for gamma in v1
    # will be 0. In version 2 of the API, if scale_after_normalization is
    # False, gamma is not used at all, and the gradient is None, which
    # displeases the gradient checker.
    for scale_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", scale_after_normalization, True,
                                  1)
    for shift_after_normalization in [True, False]:
      self._testBatchNormGradient(4, "gamma", True, shift_after_normalization,
                                  2)

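  # Compares the fused V1 gradient kernel's outputs against gradients of the
  # composed-ops implementation for the same upstream backprop signal.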
  @test_util.run_deprecated_v1
  def testBatchNormGradImpl(self):
    x_shape = [7, 5, 4, 6]
    param_shape = [6]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    backprop_val = np.random.random_sample(x_shape).astype(np.float32)
    for use_gpu in [False, True]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        backprop = constant_op.constant(backprop_val, name="backprop")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          # _batch_norm_with_global_normalization_grad is deprecated in v9
          test_util.set_producer_version(ops.get_default_graph(), 8)
          grad = gen_nn_ops.batch_norm_with_global_normalization_grad(
              x, m, v, gamma, backprop, epsilon, scale_after_normalization)
          dx, dm, dv, db, dg = grad
          self.assertEqual(grad.dx, dx)
          self.assertEqual(grad.dm, dm)
          self.assertEqual(grad.dv, dv)
          self.assertEqual(grad.db, db)
          self.assertEqual(grad.dg, dg)

          on = self._opsBatchNorm(x, m, v, beta, gamma, epsilon,
                                  scale_after_normalization, True)
          odx, odm, odv, odb, odg = gradients_impl.gradients(
              [on], [x, m, v, beta, gamma], [backprop])
          if scale_after_normalization:
            all_grads = self.evaluate(
                [dx, dm, dv, db, dg, odx, odm, odv, odb, odg])
            to_check = ["dx", "dm", "dv", "db", "dg"]
          else:
            all_grads = self.evaluate([dx, dm, dv, db, odx, odm, odv, odb])
            to_check = ["dx", "dm", "dv", "db"]
          for i, _ in enumerate(to_check):
            self.assertAllClose(
                all_grads[i + len(to_check)], all_grads[i], atol=0.000001)

  @test_util.run_deprecated_v1
  def testBatchNormKeepDims(self):
    """Test batch normalization with moments-style keep_dims parameter shapes.

    Make sure that parameters with shape (1, 1, 1, depth), as produced by
    tf.nn.moments(..., keep_dims=True), yield the same result as parameters
    with shape (depth).
    """
    x_shape = (3, 5, 4, 2)
    param_shape = (2,)
    keep_dims_param_shape = (1, 1, 1, 2)
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    m_val = np.random.random_sample(param_shape).astype(np.float32)
    v_val = np.random.random_sample(param_shape).astype(np.float32)
    beta_val = np.random.random_sample(param_shape).astype(np.float32)
    gamma_val = np.random.random_sample(param_shape).astype(np.float32)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        keep_dims_m = array_ops.reshape(
            m, keep_dims_param_shape, name="keep_dims_m")
        keep_dims_v = array_ops.reshape(
            v, keep_dims_param_shape, name="keep_dims_v")
        keep_dims_beta = array_ops.reshape(
            beta, keep_dims_param_shape, name="keep_dims_beta")
        keep_dims_gamma = array_ops.reshape(
            gamma, keep_dims_param_shape, name="keep_dims_gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            keep_dims_bn = self._tfBatchNormV2(x, keep_dims_m, keep_dims_v,
                                               keep_dims_beta, keep_dims_gamma,
                                               epsilon,
                                               scale_after_normalization,
                                               shift_after_normalization)
            tf_batch_norm, keep_dims_tf_batch_norm = sess.run(
                [bn, keep_dims_bn])
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertEqual(x_shape, keep_dims_tf_batch_norm.shape)
            self.assertAllClose(
                tf_batch_norm, keep_dims_tf_batch_norm, atol=0.000001)

  def _testBatchNormArbitraryShapes(self, x_shape, param_shape, atol=0.0001,
                                    dtype=dtypes.float32,
                                    param_dtype=dtypes.float32):
    numpy_dtype = dtype.as_numpy_dtype
    numpy_param_dtype = param_dtype.as_numpy_dtype
    x_val = np.random.random_sample(x_shape).astype(numpy_dtype)
    m_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    v_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    beta_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    gamma_val = np.random.random_sample(param_shape).astype(numpy_param_dtype)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        x = constant_op.constant(x_val, name="x")
        m = constant_op.constant(m_val, name="m")
        v = constant_op.constant(v_val, name="v")
        beta = constant_op.constant(beta_val, name="beta")
        gamma = constant_op.constant(gamma_val, name="gamma")
        epsilon = 0.001
        for scale_after_normalization in [True, False]:
          for shift_after_normalization in [True, False]:
            bn = self._tfBatchNormV2(x, m, v, beta, gamma, epsilon,
                                     scale_after_normalization,
                                     shift_after_normalization)
            np_batch_norm = self._npBatchNorm(x_val, m_val, v_val, beta_val,
                                              gamma_val, epsilon,
                                              scale_after_normalization,
                                              shift_after_normalization)
            [tf_batch_norm] = self.evaluate([bn])
            self.assertEqual(x_shape, np_batch_norm.shape)
            self.assertEqual(x_shape, tf_batch_norm.shape)
            self.assertAllClose(np_batch_norm, tf_batch_norm, atol=atol)

  def testBatchNormArbitraryShapes(self):
    """Test for a variety of shapes and moments.

    Batch normalization is expected to work regardless of the position and
    dimensionality of the 'depth' axis/axes.
    """
    self._testBatchNormArbitraryShapes((3, 3), (1, 3))
    self._testBatchNormArbitraryShapes((3, 3), (3, 1))
    self._testBatchNormArbitraryShapes((3, 2, 4, 5), (1, 2, 1, 1))
    self._testBatchNormArbitraryShapes(
        (2, 3, 2, 4, 5), (1, 1, 1, 4, 5), atol=0.005)

  def testBatchNormMixedPrecision(self):
    self._testBatchNormArbitraryShapes((3, 3), (1, 3), dtype=dtypes.float16,
                                       param_dtype=dtypes.float32, atol=0.001)


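# The sufficient statistics checked below are the element count over the
# reduced axes, sum(x - shift), and sum((x - shift)**2); NormalizeMomentsTest
# further down covers turning these back into a mean and a variance.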
class SufficientStatisticsTest(test.TestCase):

  def _npSuffStats(self, x, axes, shift, keep_dims):
    axis = tuple(axes)
    if shift is not None:
      m_ss = np.sum(x - shift, axis=axis, keepdims=keep_dims)
      v_ss = np.sum((x - shift) * (x - shift), axis=axis, keepdims=keep_dims)
    else:
      m_ss = np.sum(x, axis=axis, keepdims=keep_dims)
      v_ss = np.sum(x * x, axis=axis, keepdims=keep_dims)
    count = 1.0
    for d in xrange(x.ndim):
      if d in set(axes):
        count *= x.shape[d]
    if not keep_dims:
      shift = np.squeeze(shift, axis=axis)
    return count, m_ss, v_ss, shift

  def _opSuffStats(self, x, axes, shift, keep_dims):
    return nn_impl.sufficient_statistics(x, axes, shift, keep_dims)

  def _testSuffStats(self, x_shape, axes, shift, keep_dims, has_shape):
    x_val = np.random.random_sample(x_shape).astype(np.float32)
    np_c, np_m, np_v, np_s = self._npSuffStats(x_val, axes, shift, keep_dims)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        if has_shape:
          x = constant_op.constant(x_val, name="x")
          x.set_shape(x_shape)
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = self.evaluate([op_c, op_m, op_v, op_s])
          else:
            tf_c, tf_m, tf_v = self.evaluate([op_c, op_m, op_v])
        else:
          x = array_ops.placeholder(
              dtype=dtypes.float32, shape=[None] * len(x_shape), name="x")
          op_c, op_m, op_v, op_s = self._opSuffStats(x, axes, shift, keep_dims)
          if shift:
            tf_c, tf_m, tf_v, tf_s = sess.run([op_c, op_m, op_v, op_s],
                                              feed_dict={x: x_val})
          else:
            tf_c, tf_m, tf_v = sess.run([op_c, op_m, op_v],
                                        feed_dict={x: x_val})
        self.assertAllClose(np_c, tf_c, atol=0.000001)
        self.assertAllClose(np_m, tf_m, atol=0.000001)
        self.assertAllClose(np_v, tf_v, atol=0.000001)
        if shift:
          self.assertAllClose(np_s, tf_s, atol=0.000001)

  @test_util.run_deprecated_v1
  def testSuffStats(self):
    for has_shape in [True, False]:
      for keep_dims in [True, False]:
        for shift in [None, 1.0]:
          self._testSuffStats([2, 3], [1], shift, keep_dims, has_shape)
          self._testSuffStats([2, 3], [0], shift, keep_dims, has_shape)
          self._testSuffStats([1, 2, 3], [0, 2], shift, keep_dims, has_shape)


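# normalize_moments recovers mean = shift + mean_ss / count and
# variance = variance_ss / count - (mean_ss / count)**2, as mirrored by the
# NumPy reference in _npNormalizeMoments below.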
class NormalizeMomentsTest(test.TestCase):

  def _npNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    mean = mean_ss / counts
    variance = variance_ss / counts - mean * mean
    if shift is not None:
      mean += shift
    return mean, variance

  def _opNormalizeMoments(self, counts, mean_ss, variance_ss, shift):
    return nn_impl.normalize_moments(counts, mean_ss, variance_ss, shift)

  def _testNormalizeMoments(self, shape, shift):
    counts = np.ones([1]).astype(np.float32)
    mean_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss = np.random.random_sample(shape).astype(np.float32)
    variance_ss *= variance_ss
    if shift:
      shift_v = np.random.random_sample(shape).astype(np.float32)
    else:
      shift_v = None
    npm, npv = self._npNormalizeMoments(counts, mean_ss, variance_ss, shift_v)
    for use_gpu in [True, False]:
      with self.cached_session(use_gpu=use_gpu) as sess:
        tf_counts = constant_op.constant(counts, name="counts")
        tf_mean_ss = constant_op.constant(mean_ss, name="mean_ss")
        tf_variance_ss = constant_op.constant(variance_ss, name="variance_ss")
        if shift:
          tf_shift_v = constant_op.constant(shift_v, name="shift")
        else:
          tf_shift_v = None
        opm, opv = self._opNormalizeMoments(tf_counts, tf_mean_ss,
                                            tf_variance_ss, tf_shift_v)
        tfm, tfv = self.evaluate([opm, opv])
        self.assertAllClose(npm, tfm, atol=0.000001)
        self.assertAllClose(npv, tfv, atol=0.000001)

  def testNormalizeMoments(self):
    for shift in [None, 4.0]:
      self._testNormalizeMoments([3], shift)
      self._testNormalizeMoments([2, 3], shift)


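# MomentsTest checks nn_impl.moments against a NumPy reference computed as
# E[x] and E[x**2] - E[x]**2 over the requested axes; WeightedMomentsTest
# below reuses these tests through the _unweighted_moments hook.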
class MomentsTest(test.TestCase):

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    # Method to compute moments of `x` wrt `axes`.
    #
    # This is exposed so WeightedMomentsTest can inherit the tests and
    # assertions from MomentsTest; the extra_out_grads argument allows
    # its inherited gradient tests to assert gradients against the
    # weights as well as the input values.

    return nn_impl.moments(x, axes, keep_dims=keep_dims)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = array_ops.placeholder(dtype, shape=[None] * len(shape))

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(
          expected_mean, mean.eval(feed_dict={x: x_numpy}))
      self.assertAllCloseAccordingToType(
          expected_variance, var.eval(feed_dict={x: x_numpy}))

  def RunMomentTest(self, shape, axes, keep_dims, dtype):
    with self.cached_session():
      # shape = [batch, width, height, depth]
      assert len(shape) == 4

      x_numpy = np.random.normal(size=shape).astype(np.float32)
      x = math_ops.cast(constant_op.constant(x_numpy), dtype=dtype)

      # Compute the expected values at high precision since the method
      # is prone to catastrophic cancellation:
      x_numpy = x_numpy.astype(np.float128)

      mean, var = self._unweighted_moments(x, axes, keep_dims=keep_dims)

      num_elements = np.prod([shape[i] for i in axes])

      ax = tuple(axes)
      expected_mean = np.sum(x_numpy, axis=ax,
                             keepdims=keep_dims) / num_elements
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = np.sum(np.multiply(x_numpy, x_numpy),
                                  axis=ax,
                                  keepdims=keep_dims) / num_elements
      expected_variance = expected_x_squared - expected_mean_squared

      # Check that the moments are correct.
      self.assertAllCloseAccordingToType(expected_mean, self.evaluate(mean))
      self.assertAllCloseAccordingToType(expected_variance, self.evaluate(var))

  @test_util.run_deprecated_v1
  def testBasic(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4], axes=[0], keep_dims=keep_dims, dtype=dtype)

  @test_util.run_deprecated_v1
  def testGlobalNormalization(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[0, 1, 2],
            keep_dims=keep_dims,
            dtype=dtype)

  @test_util.run_deprecated_v1
  def testAxes(self):
    for keep_dims in [False, True]:
      for dtype in [dtypes.float32, dtypes.float16]:
        self.RunMomentTest(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)
        self.RunMomentTestWithDynamicShape(
            shape=[2, 3, 5, 4],
            axes=[1, 2, 3],
            keep_dims=keep_dims,
            dtype=dtype)

  def _testGlobalGradient(self, from_y="mean"):
    with self.cached_session():
      x_shape = [3, 5, 4, 2]
      x_val = np.random.random_sample(x_shape).astype(np.float64)
      x = constant_op.constant(x_val)
      x.set_shape(x_shape)

      axes = [0, 1, 2]
      y_shape = [2]  # Depth of x

      inputs_to_compute_gradients_for = [x]

      out_mean, out_var = self._unweighted_moments(
          x, axes, extra_out_grads=inputs_to_compute_gradients_for)
      if from_y == "mean":
        y = out_mean
      elif from_y == "var":
        y = out_var

      for (i, v) in enumerate(inputs_to_compute_gradients_for):
        err = gradient_checker.compute_gradient_error(v,
                                                      v.get_shape().as_list(),
                                                      y, y_shape)
        print("Moments %s gradient err vs input %d = %g" % (from_y, i, err))
        self.assertLess(err, 1e-11)

  @test_util.run_deprecated_v1
  def testMeanGlobalGradient(self):
    self._testGlobalGradient(from_y="mean")

  @test_util.run_deprecated_v1
  def testVarGlobalGradient(self):
    self._testGlobalGradient(from_y="var")


class WeightedMomentsTest(MomentsTest):
  """Tests for nn.weighted_moments.

  Note that this test inherits from MomentsTest, inheriting all its
  test methods!

  It modifies MomentsTest in two ways:

  a) By overriding _unweighted_moments, all the codepaths in
     MomentsTest are executed, but with calls to tf.nn.moments()
     replaced by calls to tf.nn.weighted_moments() with a constant
     weight of 1.

  b) By overriding RunMomentTest and RunMomentTestWithDynamicShape,
     this test adds multiple additional calls to
     RunWeightedMomentsTest() to exercise correctness with
     non-constant weights and varying broadcasting situations. (It
     also continues to call MomentsTest.Run(Weighted)?MomentsTest as
     well.)
  """

  def _unweighted_moments(self, x, axes, keep_dims=False, extra_out_grads=None):
    weights = constant_op.constant(1, dtype=x.dtype)
    if extra_out_grads is not None:
      # We want to assert gradients WRT weights as well as X!
      extra_out_grads.append(weights)
    return nn_impl.weighted_moments(x, axes, weights, keep_dims=keep_dims)

  def RunMomentTest(self, shape, axes, keep_dims, dtype, dynshapes=False):
    if not dynshapes:
      super(WeightedMomentsTest, self).RunMomentTest(shape, axes, keep_dims,
                                                     dtype)
    else:
      super(WeightedMomentsTest, self).RunMomentTestWithDynamicShape(shape,
                                                                     axes,
                                                                     keep_dims,
                                                                     dtype)

    # 1:1 weights and inputs
    self.RunWeightedMomentTest(shape, shape, axes, keep_dims, dtype)

    # Various broadcasting combinations
    for idx in range(len(shape)):
      # try broadcasting weights in all positions
      weight_shape = [1] * len(shape)
      weight_shape[idx] = shape[idx]

      self.RunWeightedMomentTest(shape, weight_shape, axes, keep_dims, dtype)

      # Also try broadcasting with a suffix of length n
      weight_shape = shape[-(idx + 1):]
      self.RunWeightedMomentTest(
          shape, weight_shape, axes, keep_dims, dtype, dynshapes=dynshapes)

  def RunMomentTestWithDynamicShape(self, shape, axes, keep_dims, dtype):
    self.RunMomentTest(shape, axes, keep_dims, dtype, dynshapes=True)

  def RunWeightedMomentTest(self,
                            shape,
                            weights_shape,
                            axes,
                            keep_dims,
                            dtype,
                            dynshapes=False):
    with self.cached_session() as s:
      x_numpy = np.random.normal(size=shape).astype(np.float32)
      weights_numpy = np.absolute(  # weights must be positive
          np.random.normal(
              size=weights_shape, loc=1.0).astype(np.float32))

      # Expand the numpy version to higher precision
      x_numpy = x_numpy.astype(np.float128)
      weights_numpy = weights_numpy.astype(np.float128)

      x_shape = [None] * len(shape) if dynshapes else shape
      weights_shape = ([None] * len(weights_shape) if dynshapes else
                       weights_shape)

      x = array_ops.placeholder(dtype, shape=x_shape)
      weights = array_ops.placeholder(dtype, shape=weights_shape)

      mean, var = nn_impl.weighted_moments(
          x, axes, weights, keep_dims=keep_dims)

      ax = tuple(axes)

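      # NumPy reference for the weighted moments: mean = sum(w * x) / sum(w)
      # and variance = sum(w * x**2) / sum(w) - mean**2, with the weights
      # broadcast against x.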
      def _np_weighted_sum(v):
        return np.sum(weights_numpy * v, axis=ax, keepdims=keep_dims)

      weight_sum = _np_weighted_sum(np.ones_like(x_numpy))
      expected_mean = _np_weighted_sum(x_numpy) / weight_sum
      expected_mean_squared = np.multiply(expected_mean, expected_mean)
      expected_x_squared = (_np_weighted_sum(np.multiply(x_numpy, x_numpy)) /
                            weight_sum)
      expected_variance = expected_x_squared - expected_mean_squared

      mean_v, var_v = s.run([mean, var],
                            feed_dict={x: x_numpy,
                                       weights: weights_numpy})

      self.assertAllCloseAccordingToType(expected_mean, mean_v)
      self.assertAllCloseAccordingToType(expected_variance, var_v)


if __name__ == "__main__":
  test.main()