• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Functional tests for pooling operations."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import numpy as np
22
23from tensorflow.compiler.tests import xla_test
24from tensorflow.python.framework import dtypes
25from tensorflow.python.framework import ops
26from tensorflow.python.ops import array_ops
27from tensorflow.python.ops import gen_nn_ops
28from tensorflow.python.ops import nn_ops
29from tensorflow.python.platform import googletest
30
31
def NHWCToNCHW(input_tensor):
  """Reorders NHWC-laid-out data into NCHW order.

  Args:
    input_tensor: a 4-D `ops.Tensor`, or a 4-element indexable sequence
      (such as a shape, ksize or strides list) ordered [N, H, W, C].

  Returns:
    For a tensor, the result of transposing it to [N, C, H, W]; otherwise a
    new list holding the same four elements in [N, C, H, W] order.
  """
  nchw_axes = [0, 3, 1, 2]
  if not isinstance(input_tensor, ops.Tensor):
    return [input_tensor[axis] for axis in nchw_axes]
  return array_ops.transpose(input_tensor, nchw_axes)
45
46
def NCHWToNHWC(input_tensor):
  """Reorders NCHW-laid-out data back into NHWC order.

  Args:
    input_tensor: a 4-D `ops.Tensor`, or a 4-element indexable sequence
      (such as a shape, ksize or strides list) ordered [N, C, H, W].

  Returns:
    For a tensor, the result of transposing it to [N, H, W, C]; otherwise a
    new list holding the same four elements in [N, H, W, C] order.
  """
  nhwc_axes = [0, 2, 3, 1]
  if not isinstance(input_tensor, ops.Tensor):
    return [input_tensor[axis] for axis in nhwc_axes]
  return array_ops.transpose(input_tensor, nhwc_axes)
60
61
def GetTestConfigs():
  """Lists the data layouts that every pooling test is run under.

  Returns:
    A new list of data-format strings: "NHWC" followed by "NCHW".
  """
  return ["NHWC", "NCHW"]
70
71
class PoolingTest(xla_test.XLATestCase):
  """Checks forward max/avg pooling results computed under the XLA scope."""

  def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding,
                     data_format, expected):
    """Verifies the output values of the pooling function for one layout.

    Args:
      pool_func: Pooling function to be called, e.g. nn_ops.max_pool or
        nn_ops.avg_pool.
      input_sizes: Input tensor dimensions, in NHWC order.
      ksize: The kernel size dimensions, in NHWC order.
      strides: The stride dimensions, in NHWC order.
      padding: Padding type, e.g. "SAME" or "VALID".
      data_format: The data format we use to run the pooling operation
        ("NHWC" or "NCHW").
      expected: An array containing the expected operation outputs, flattened
        in NHWC order.
    """
    total_size = np.prod(input_sizes)
    # Initializes the input tensor with incrementing numbers starting at 1
    # (built the same way as in PoolGradTest). These small integers are
    # exactly representable in float32.
    x = np.arange(1, total_size + 1, dtype=np.float32).reshape(input_sizes)
    with self.cached_session() as sess:
      with self.test_scope():
        inputs = array_ops.placeholder(dtypes.float32)
        t = inputs
        if data_format == "NCHW":
          # The test parameters are written in NHWC; convert the input,
          # kernel and strides so the op runs in NCHW.
          t = NHWCToNCHW(t)
          ksize = NHWCToNCHW(ksize)
          strides = NHWCToNCHW(strides)
        t = pool_func(t,
                      ksize=ksize,
                      strides=strides,
                      padding=padding,
                      data_format=data_format)
        if data_format == "NCHW":
          # Convert back so the result lines up with the NHWC expectations.
          t = NCHWToNHWC(t)
      actual = sess.run(t, {inputs: x})
      self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6)

  def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding,
                    expected):
    """Verifies the pooling output in every layout from GetTestConfigs().

    Args:
      pool_func: Pooling function to be called, e.g. nn_ops.max_pool or
        nn_ops.avg_pool.
      input_sizes: Input tensor dimensions, in NHWC order.
      ksize: The kernel size dimensions, in NHWC order.
      strides: The stride dimensions, in NHWC order.
      padding: Padding type, e.g. "SAME" or "VALID".
      expected: An array containing the expected operation outputs, flattened
        in NHWC order.
    """
    for data_format in GetTestConfigs():
      self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding,
                          data_format, expected)

  def testMaxPoolValidPadding(self):
    expected_output = [13.0, 14.0, 15.0]
    self._VerifyValues(nn_ops.max_pool,
                       input_sizes=[1, 3, 3, 3],
                       ksize=[1, 2, 2, 1],
                       strides=[1, 2, 2, 1],
                       padding="VALID",
                       expected=expected_output)

  def testMaxPoolSamePadding(self):
    expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0]
    self._VerifyValues(nn_ops.max_pool,
                       input_sizes=[1, 2, 3, 3],
                       ksize=[1, 2, 2, 1],
                       strides=[1, 2, 2, 1],
                       padding="SAME",
                       expected=expected_output)

  def testMaxPoolSamePaddingNonSquareWindow(self):
    # input is:
    # [1.0, 2.0
    #  3.0  4.0]
    #
    # A 1x2 window (one row, two columns) with SAME padding should do:
    #
    #  [max(1.0, 2.0), max(2.0, <padding>),
    #   max(3.0, 4.0), max(4.0, <padding>)]
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 2, 2, 1],
        ksize=[1, 1, 2, 1],
        strides=[1, 1, 1, 1],
        padding="SAME",
        expected=[2.0, 2.0, 4.0, 4.0])

  def testMaxPoolValidPaddingUnevenStride(self):
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 4, 4, 1],
        ksize=[1, 2, 2, 1],
        strides=[1, 1, 2, 1],
        padding="VALID",
        expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0])
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 4, 4, 1],
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 1, 1],
        padding="VALID",
        expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0])

  def testMaxPoolSamePaddingFilter4(self):
    expected_output = [
        21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0,
        61.0, 62.0, 63.0, 64.0
    ]
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 4, 4, 4],
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding="SAME",
        expected=expected_output)

  def testMaxPoolSamePaddingFilter8(self):
    expected_output = [
        145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0,
        163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0,
        181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0,
        191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0,
        289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0,
        307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0,
        317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0,
        407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0,
        433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0,
        443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0,
        469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0,
        487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0,
        505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0
    ]
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 8, 8, 8],
        ksize=[1, 3, 3, 1],
        strides=[1, 2, 2, 1],
        padding="SAME",
        expected=expected_output)

  # Depthwise max pooling: the window and stride cover only the depth
  # (channel) dimension, with a 1x1 spatial window.
  def testDepthwiseMaxPool1x1DepthWindow1(self):
    # input is:
    # [1.0, ..., 10.0] along depth,
    #
    # We maxpool by depth in patches of 2.
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 1, 1, 10],
        ksize=[1, 1, 1, 2],
        strides=[1, 1, 1, 2],
        padding="SAME",
        expected=[2.0, 4.0, 6.0, 8.0, 10.0])

  def testDepthwiseMaxPool2x2DepthWindow3(self):
    # input is:
    #
    # a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2
    # output.  Each node has contiguous values, so the depthwise max
    # should be multiples of 3.0.
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 2, 2, 6],
        ksize=[1, 1, 1, 3],
        strides=[1, 1, 1, 3],
        padding="SAME",
        expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0])

  def testKernelSmallerThanStrideValid(self):
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 7, 7, 1],
        ksize=[1, 2, 2, 1],
        strides=[1, 3, 3, 1],
        padding="VALID",
        expected=[9, 12, 30, 33])

  def testKernelSmallerThanStrideSame(self):
    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 3, 3, 1],
        ksize=[1, 1, 1, 1],
        strides=[1, 2, 2, 1],
        padding="SAME",
        expected=[1, 3, 7, 9])

    self._VerifyValues(
        nn_ops.max_pool,
        input_sizes=[1, 4, 4, 1],
        ksize=[1, 1, 1, 1],
        strides=[1, 2, 2, 1],
        padding="SAME",
        expected=[1, 3, 9, 11])

  # Average pooling
  def testAvgPoolValidPadding(self):
    expected_output = [7, 8, 9]
    self._VerifyValues(
        nn_ops.avg_pool,
        input_sizes=[1, 3, 3, 3],
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding="VALID",
        expected=expected_output)

  def testAvgPoolSamePadding(self):
    expected_output = [7., 8., 9., 11.5, 12.5, 13.5]
    self._VerifyValues(
        nn_ops.avg_pool,
        input_sizes=[1, 2, 3, 3],
        ksize=[1, 2, 2, 1],
        strides=[1, 2, 2, 1],
        padding="SAME",
        expected=expected_output)
289
290
class PoolGradTest(xla_test.XLATestCase):
  """Compares XLA pooling gradients with the TensorFlow CPU kernels."""

  # Device on which the reference (expected) results are computed.
  CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0"

  def _VerifyOneTest(self,
                     pool_func,
                     pool_grad_func,
                     input_sizes,
                     ksize,
                     strides,
                     padding,
                     data_format,
                     pool_grad_grad_func=None):
    """Verifies the output values of the pooling gradient function.

    Expected values come from running pool_grad_func (and
    pool_grad_grad_func, if given) on CPU_DEVICE in NHWC. They are compared
    against the same functions run under the XLA test scope in data_format.

    Args:
      pool_func: Forward pooling function.
      pool_grad_func: Pooling gradient function for pool_func.
      input_sizes: Input tensor dimensions, in NHWC order.
      ksize: The kernel size dimensions, in NHWC order.
      strides: The stride dimensions, in NHWC order.
      padding: Padding type, e.g. "SAME" or "VALID".
      data_format: The data format we use to run the pooling operation
        under XLA ("NHWC" or "NCHW").
      pool_grad_grad_func: Second-order gradient function, if available.
    """
    total_size = np.prod(input_sizes)
    # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally
    # maximal at 16 bits. Switch to np.random.randn when resolved.
    x = np.arange(1, total_size + 1, dtype=np.float32)
    x *= (np.random.randint(2, size=total_size) * 2 - 1)  # Flip signs randomly
    # Verify some specifically interesting values...
    x[np.random.choice(total_size)] = np.inf
    x[np.random.choice(total_size)] = -np.inf
    # TODO(b/74222344): Fix nan handling for max pool grad.
    # x[np.random.choice(total_size)] = np.nan
    x = x.reshape(input_sizes)
    with self.cached_session() as sess:
      # Use the forward pool function to compute some corresponding outputs
      # (needed for the CPU device, and we need the shape in both cases).
      with ops.device(self.CPU_DEVICE):
        inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes)
        outputs = pool_func(
            inputs,
            ksize=ksize,
            strides=strides,
            padding=padding,
            data_format="NHWC")

      output_vals = np.array(sess.run(outputs, {inputs: x}))
      # Deterministic upstream gradients: 1..N shaped like the forward output
      # (first order) and like the input (second order).
      output_gradient_vals = np.arange(
          1, output_vals.size + 1, dtype=np.float32)
      output_gradient_vals = output_gradient_vals.reshape(output_vals.shape)
      output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32)
      output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape)

      # Use the Tensorflow CPU pooling gradient to compute the expected input
      # gradients.
      with ops.device(self.CPU_DEVICE):
        output_gradients = array_ops.placeholder(
            dtypes.float32, shape=output_vals.shape)
        expected_input_gradients = pool_grad_func(
            inputs,
            outputs,
            output_gradients,
            ksize=ksize,
            strides=strides,
            padding=padding,
            data_format="NHWC")
        expected_input_gradient_vals = sess.run(
            expected_input_gradients,
            {inputs: x,
             output_gradients: output_gradient_vals})

        output_grad_gradients = array_ops.placeholder(
            dtypes.float32, shape=expected_input_gradient_vals.shape)
        if pool_grad_grad_func is not None:
          expected_grad_gradients = pool_grad_grad_func(
              inputs,
              outputs,
              output_grad_gradients,
              ksize=ksize,
              strides=strides,
              padding=padding,
              data_format="NHWC")
          expected_grad_gradients_vals = sess.run(expected_grad_gradients, {
              inputs: x,
              output_grad_gradients: output_grad_grad_vals
          })

      # Run the gradient op on the XLA device
      with self.test_scope():
        # Rebind `outputs` to a placeholder: the XLA ops are fed the CPU
        # forward results (output_vals) instead of recomputing them.
        outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape)
        xla_inputs = inputs
        xla_outputs = outputs
        xla_output_gradients = output_gradients
        xla_output_grad_gradients = output_grad_gradients
        xla_ksize = ksize
        xla_strides = strides
        if data_format == "NCHW":
          xla_inputs = NHWCToNCHW(inputs)
          xla_outputs = NHWCToNCHW(outputs)
          xla_output_gradients = NHWCToNCHW(output_gradients)
          xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients)
          xla_ksize = NHWCToNCHW(ksize)
          xla_strides = NHWCToNCHW(strides)
        actual_input_gradients = pool_grad_func(
            xla_inputs,
            xla_outputs,
            xla_output_gradients,
            ksize=xla_ksize,
            strides=xla_strides,
            padding=padding,
            data_format=data_format)
        if data_format == "NCHW":
          actual_input_gradients = NCHWToNHWC(actual_input_gradients)
        if pool_grad_grad_func is not None:
          actual_grad_gradients = pool_grad_grad_func(
              xla_inputs,
              xla_outputs,
              xla_output_grad_gradients,
              ksize=xla_ksize,
              strides=xla_strides,
              padding=padding,
              data_format=data_format)
          if data_format == "NCHW":
            actual_grad_gradients = NCHWToNHWC(actual_grad_gradients)
      actual_input_gradients_vals = sess.run(actual_input_gradients, {
          inputs: x,
          outputs: output_vals,
          output_gradients: output_gradient_vals
      })
      # Compare the Tensorflow and XLA results.
      self.assertAllClose(
          expected_input_gradient_vals,
          actual_input_gradients_vals,
          rtol=1e-4,
          atol=1e-6)
      self.assertShapeEqual(actual_input_gradients_vals, inputs)

      if pool_grad_grad_func is not None:
        actual_grad_gradients_vals = sess.run(
            actual_grad_gradients, {
                inputs: x,
                outputs: output_vals,
                output_grad_gradients: output_grad_grad_vals
            })

        # Compare the Tensorflow and XLA results.
        self.assertAllClose(
            expected_grad_gradients_vals,
            actual_grad_gradients_vals,
            rtol=1e-4,
            atol=1e-6)
        self.assertShapeEqual(actual_grad_gradients_vals, outputs)

  def _VerifyValues(self,
                    pool_func,
                    pool_grad_func,
                    input_sizes,
                    ksize,
                    strides,
                    padding,
                    pool_grad_grad_func=None):
    """Verifies the pooling gradients in every layout from GetTestConfigs().

    Args:
      pool_func: Pooling function to be called, e.g., tf.nn.max_pool
      pool_grad_func: Corresponding pooling gradient function.
      input_sizes: Input tensor dimensions, in NHWC order.
      ksize: The kernel size dimensions, in NHWC order.
      strides: The stride dimensions, in NHWC order.
      padding: Padding type, e.g. "SAME" or "VALID".
      pool_grad_grad_func: Second-order gradient function, if available.
    """
    for data_format in GetTestConfigs():
      self._VerifyOneTest(
          pool_func,
          pool_grad_func,
          input_sizes,
          ksize,
          strides,
          padding,
          data_format,
          pool_grad_grad_func=pool_grad_grad_func)

  def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None):
    """Runs the shared set of pooling-gradient cases.

    Args:
      forward_op: Forward pooling function, e.g. nn_ops.max_pool.
      backward_op: Gradient function called as
        backward_op(inputs, outputs, output_gradients, ksize=..., strides=...,
        padding=..., data_format=...).
      pool_grad_grad_func: Optional second-order gradient function.
    """
    # (input_sizes, ksize, strides, padding), run in this order. The order
    # matters only because each case draws from np.random.
    cases = (
        # VALID padding
        ([1, 3, 3, 3], [1, 2, 2, 1], [1, 2, 2, 1], "VALID"),
        # SAME padding
        ([1, 2, 3, 3], [1, 2, 2, 1], [1, 2, 2, 1], "SAME"),
        # SAME padding, non square window
        ([1, 2, 2, 1], [1, 1, 2, 1], [1, 1, 1, 1], "SAME"),
        # VALID padding, uneven stride
        ([1, 4, 4, 1], [1, 2, 2, 1], [1, 1, 2, 1], "VALID"),
        ([1, 4, 4, 1], [1, 2, 2, 1], [1, 2, 1, 1], "VALID"),
        # SAME padding, size 4 input
        ([1, 4, 4, 4], [1, 2, 2, 1], [1, 2, 2, 1], "SAME"),
        # SAME padding, size 8 input
        ([1, 8, 8, 8], [1, 3, 3, 1], [1, 2, 2, 1], "SAME"),
    )
    for input_sizes, ksize, strides, padding in cases:
      self._VerifyValues(
          forward_op,
          backward_op,
          input_sizes=input_sizes,
          ksize=ksize,
          strides=strides,
          padding=padding,
          pool_grad_grad_func=pool_grad_grad_func)

  def testMaxPool(self):
    self._TestPooling(
        nn_ops.max_pool,
        gen_nn_ops.max_pool_grad,
        pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad)

  def testAvgPool(self):
    # Wrapper around AvgPoolGrad that ignores extra arguments needed by
    # MaxPoolGrad.
    def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding,
                    data_format):
      del outputs  # Unused by average-pooling gradients.
      return gen_nn_ops.avg_pool_grad(
          inputs.get_shape().as_list(),
          output_gradients,
          ksize=ksize,
          strides=strides,
          padding=padding,
          data_format=data_format)

    self._TestPooling(nn_ops.avg_pool, AvgPoolGrad)

  # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than
  # the stride size, so we only run the following tests on MaxPoolGrad.

  def testMaxPoolKernelSmallerThanStrideValid(self):
    self._VerifyValues(
        nn_ops.max_pool,
        gen_nn_ops.max_pool_grad,
        input_sizes=[1, 7, 7, 1],
        ksize=[1, 2, 2, 1],
        strides=[1, 3, 3, 1],
        padding="VALID")

  def testMaxPoolKernelSmallerThanStrideSame(self):
    self._VerifyValues(
        nn_ops.max_pool,
        gen_nn_ops.max_pool_grad,
        input_sizes=[1, 3, 3, 1],
        ksize=[1, 1, 1, 1],
        strides=[1, 2, 2, 1],
        padding="SAME")

    self._VerifyValues(
        nn_ops.max_pool,
        gen_nn_ops.max_pool_grad,
        input_sizes=[1, 4, 4, 1],
        ksize=[1, 1, 1, 1],
        strides=[1, 2, 2, 1],
        padding="SAME")
595
596
if __name__ == "__main__":
  # Run the test cases in this module through TensorFlow's googletest entry
  # point.
  googletest.main()
599