# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for fractional max pool operation."""
16
17import math
18
19import numpy as np
20
21from tensorflow.python.framework import constant_op
22from tensorflow.python.framework import dtypes
23from tensorflow.python.framework import errors
24from tensorflow.python.framework import test_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import gen_nn_ops
27from tensorflow.python.ops import gradient_checker
28from tensorflow.python.ops import nn_ops
29import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
30from tensorflow.python.platform import test
31
32
class FractionalMaxPoolTest(test.TestCase):

  # Random number generator with fixed seed.
  _PRNG = np.random.RandomState(341261)
  _SEED = 123456

  def _MaxPoolAlongRows(self, input_matrix, row_seq, overlapping):
    """Perform max pool along the rows of a 2-D matrix based on row_seq.

    Args:
      input_matrix: A 2-D matrix.
      row_seq: Cumulative pooling sequence along the rows.
      overlapping: Whether or not to use overlapping when pooling.

    Returns:
      A 2-D matrix, with
        * num_rows = len(row_seq)-1
        * num_cols = input_matrix.num_cols.
    """
    output_image = np.zeros(input_matrix.shape[1])
    row_max = row_seq[-1]
    for i in range(row_seq.shape[0] - 1):
      row_start = row_seq[i]
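      # With overlapping, the pooling region includes the boundary row shared
      # with the next region, i.e. rows [row_seq[i], row_seq[i + 1]] inclusive.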
      row_end = row_seq[i + 1] + 1 if overlapping else row_seq[i + 1]
      row_end = min(row_end, row_max)
      output_image = np.vstack((output_image, np.amax(
          input_matrix[row_start:row_end, :], axis=0)))  # axis 0 is along row
    # remove the sentinel row
    return output_image[1:, :]

  def _MaxPoolAlongCols(self, input_matrix, col_seq, overlapping):
    """Perform max pool along the columns of a 2-D matrix based on col_seq.

    Args:
      input_matrix: A 2-D matrix.
      col_seq: Cumulative pooling sequence along the columns.
      overlapping: Whether or not to use overlapping when pooling.

    Returns:
      A 2-D matrix, with
        * num_rows = input_matrix.num_rows
        * num_cols = len(col_seq)-1.
    """
    input_matrix = input_matrix.transpose()
    output_matrix = self._MaxPoolAlongRows(input_matrix, col_seq, overlapping)
    return output_matrix.transpose()

  def _GetExpectedFractionalMaxPoolResult(self, input_tensor, row_seq, col_seq,
                                          overlapping):
    """Get expected fractional max pool result.

    row_seq and col_seq together define the fractional pooling regions.

    Args:
      input_tensor: Original input tensor, assumed to be a 4-D tensor, with
        dimensions as [batch, height/row, width/column, channels/depth].
      row_seq: Cumulative pooling sequence along the rows.
      col_seq: Cumulative pooling sequence along the columns.
      overlapping: Whether to use overlapping when pooling.

    Returns:
      A 4-D tensor that is the result of max pooling on input_tensor based on
        the pooling regions defined by row_seq and col_seq, conditioned on
        whether or not overlapping is used.
    """
    input_shape = input_tensor.shape
    output_shape = (input_shape[0], len(row_seq) - 1, len(col_seq) - 1,
                    input_shape[3])
    output_tensor = np.zeros(shape=output_shape, dtype=input_tensor.dtype)
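    # Pool each (batch, channel) 2-D slice independently: first along rows,
    # then along columns.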
    for batch in range(input_shape[0]):
      for channel in range(input_shape[3]):
        two_dim_slice = input_tensor[batch, :, :, channel]
        tmp = self._MaxPoolAlongRows(two_dim_slice, row_seq, overlapping)
        output_tensor[batch, :, :, channel] = self._MaxPoolAlongCols(
            tmp, col_seq, overlapping)

    return output_tensor

  def _ValidateFractionalMaxPoolResult(self, input_tensor, pooling_ratio,
                                       pseudo_random, overlapping):
    """Validate FractionalMaxPool's result against the expected result.

    The expected result is computed from input_tensor and the pooling regions
    defined by row_seq and col_seq.

    Args:
      input_tensor: A tensor or numpy ndarray.
      pooling_ratio: A list or tuple of length 4, whose first and last elements
        must be 1.
      pseudo_random: Whether to use the pseudo-random method to generate the
        pooling sequence.
      overlapping: Whether to use overlapping when pooling.

    Returns:
      None
    """
    with self.cached_session():
      p, r, c = nn_ops.fractional_max_pool_v2(
          input_tensor,
          pooling_ratio,
          pseudo_random,
          overlapping,
          seed=self._SEED)
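      # The op also returns the row/col pooling sequences; feed them to the
      # NumPy reference implementation to compute the expected result.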
      actual, row_seq, col_seq = self.evaluate([p, r, c])
      expected = self._GetExpectedFractionalMaxPoolResult(input_tensor, row_seq,
                                                          col_seq, overlapping)
      self.assertShapeEqual(expected, p)
      self.assertAllClose(expected, actual)

  def _testVisually(self):
    """Manual test that prints intermediate results for a small random tensor.

    Since _GetExpectedFractionalMaxPoolResult is 'automated', it feels safer to
    have a test case where you can see what's happening.
    This test generates a small random integer 2-D matrix and feeds it to both
    FractionalMaxPool and _GetExpectedFractionalMaxPoolResult.
    """
    num_rows = 6
    num_cols = 6
    tensor_shape = (1, num_rows, num_cols, 1)
    pseudo_random = False
    for overlapping in True, False:
      print("-" * 70)
      print("Testing FractionalMaxPool with overlapping = {}".format(
          overlapping))
      rand_mat = self._PRNG.randint(10, size=tensor_shape)
      pooling_ratio = [1, math.sqrt(2), math.sqrt(2), 1]
      with self.cached_session():
        p, r, c = nn_ops.fractional_max_pool_v2(
            rand_mat,
            pooling_ratio,
            pseudo_random,
            overlapping,
            seed=self._SEED)
        tensor_output, row_seq, col_seq = self.evaluate([p, r, c])
        expected_result = self._GetExpectedFractionalMaxPoolResult(rand_mat,
                                                                   row_seq,
                                                                   col_seq,
                                                                   overlapping)
        print("row sequence:")
        print(row_seq)
        print("column sequence:")
        print(col_seq)

        print("Input:")
        # Print input with pooling region marked.
        for i in range(num_rows):
          row_to_print = []
          for j in range(num_cols):
            if j in col_seq:
              row_to_print.append("|")
            row_to_print.append(str(rand_mat[0, i, j, 0]))
          row_to_print.append("|")
          if i in row_seq:
            print("-" * 2 * len(row_to_print))
          print(" ".join(row_to_print))
        print("-" * 2 * len(row_to_print))

        print("Output from FractionalMaxPool:")
        print(tensor_output[0, :, :, 0])
        print("Expected result:")
        print(expected_result[0, :, :, 0])

  def testAllInputOptions(self):
    """Try all possible input options for fractional_max_pool.
    """
    num_batches = 5
    num_channels = 3
    num_rows = 20
    num_cols = 30
    for pseudo_random in True, False:
      for overlapping in True, False:
        tensor_shape = (num_batches, num_rows, num_cols, num_channels)
        # random tensor with value in [-500.0, 500.0)
        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
        self._ValidateFractionalMaxPoolResult(
            rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
            overlapping)

  def testIntegerTensorInput(self):
    """Test that the op works when the input tensor is of integer type.
    """
    num_batches = 5
    num_channels = 3
    num_rows = 20
    num_cols = 30
    pseudo_random = True
    overlapping = True
    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
    rand_mat = self._PRNG.randint(1000, size=tensor_shape)
    self._ValidateFractionalMaxPoolResult(rand_mat,
                                          [1, math.sqrt(3), math.sqrt(2), 1],
                                          pseudo_random, overlapping)

  def testDifferentTensorShapes(self):
    """Test different shapes of input tensor.

    Mainly test different combinations of num_rows and num_cols.
    """
    pseudo_random = True
    overlapping = True
    for num_batches in [1, 3]:
      for num_channels in [1, 3]:
        for num_rows in [10, 20, 50]:
          for num_cols in [10, 20, 50]:
            tensor_shape = (num_batches, num_rows, num_cols, num_channels)
            # random tensor with value in [-500.0, 500.0)
            rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
            self._ValidateFractionalMaxPoolResult(
                rand_mat, [1, math.sqrt(3), math.sqrt(2), 1], pseudo_random,
                overlapping)

  def testLargePoolingRatio(self):
    """Test when pooling ratio is not within [1, 2).
    """
    pseudo_random = True
    overlapping = True
    num_batches = 3
    num_channels = 3
    num_rows = 30
    num_cols = 50
    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
    for row_ratio in [math.sqrt(11), math.sqrt(37)]:
      for col_ratio in [math.sqrt(11), math.sqrt(27)]:
        # random tensor with value in [-500.0, 500.0)
        rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
        self._ValidateFractionalMaxPoolResult(rand_mat,
                                              [1, row_ratio, col_ratio, 1],
                                              pseudo_random, overlapping)

  def testDivisiblePoolingRatio(self):
    """Test when the pooling ratio evenly divides the number of rows/cols.

    This is a case regular max pooling can handle. It should be handled by
    fractional pooling as well.
    """
    pseudo_random = True
    overlapping = True
    num_batches = 3
    num_channels = 3
    num_rows = 30
    num_cols = 50
    tensor_shape = (num_batches, num_rows, num_cols, num_channels)
    # random tensor with value in [-500.0, 500.0)
    rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
    self._ValidateFractionalMaxPoolResult(rand_mat, [1, 2, 2, 1], pseudo_random,
                                          overlapping)

  @test_util.run_deprecated_v1
  def testDifferentInputTensorShape(self):
    """Runs the operation in one session with different input tensor shapes."""
    with self.cached_session() as sess:
      input_holder = array_ops.placeholder(dtypes.float32,
                                           [None, None, None, 3])
      pooling_ratio = [1, 1.5, 1.5, 1]
      pseudo_random = False
      overlapping = False
      p, r, c = nn_ops.fractional_max_pool_v2(
          input_holder,
          pooling_ratio,
          pseudo_random,
          overlapping,
          seed=self._SEED)
      # First run.
      input_a = np.zeros([3, 32, 32, 3])
      actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_a})
      expected = self._GetExpectedFractionalMaxPoolResult(
          input_a, row_seq, col_seq, overlapping)
      self.assertSequenceEqual(expected.shape, actual.shape)
      # Second run.
      input_b = np.zeros([4, 45, 45, 3])
      actual, row_seq, col_seq = sess.run([p, r, c], {input_holder: input_b})
      expected = self._GetExpectedFractionalMaxPoolResult(
          input_b, row_seq, col_seq, overlapping)
      self.assertSequenceEqual(expected.shape, actual.shape)

  def testDeterminismExceptionThrowing(self):
    tensor_shape = (5, 20, 20, 3)
    rand_mat = self._PRNG.random_sample(tensor_shape) * 1000 - 500
    with test_util.deterministic_ops():
      with self.assertRaisesRegex(
          ValueError, "requires a non-zero seed to be passed in when "
          "determinism is enabled"):
        nn_ops.fractional_max_pool_v2(rand_mat, [1, 1.5, 1.5, 1])
      nn_ops.fractional_max_pool_v2(rand_mat, [1, 1.5, 1.5, 1], seed=1)

      with self.assertRaisesRegex(ValueError,
                                  'requires "seed" and "seed2" to be non-zero'):
        nn_ops.fractional_max_pool(rand_mat, [1, 1.5, 1.5, 1])
      nn_ops.fractional_max_pool(
          rand_mat, [1, 1.5, 1.5, 1], seed=1, seed2=1, deterministic=True)

  def testPoolingRatio(self):
    with self.cached_session() as _:
      with self.assertRaisesRegex(
          errors.InvalidArgumentError,
          r"Pooling ratio is higher than input dimension size for dimension 1.*"
      ):
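        # The input below has height and width of 1, so any spatial pooling
        # ratio greater than 1 is invalid.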
        result = nn_ops.gen_nn_ops.fractional_max_pool(
            value=constant_op.constant(
                value=[[[[1, 4, 2, 3]]]], dtype=dtypes.int64),
            pooling_ratio=[1.0, 1.44, 1.73, 1.0],
            pseudo_random=False,
            overlapping=False,
            deterministic=False,
            seed=0,
            seed2=0,
            name=None)
        self.evaluate(result)


class FractionalMaxPoolGradTest(test.TestCase):
  """Tests for FractionalMaxPoolGrad.

  Two types of tests for FractionalMaxPoolGrad:
  1) Test fractional_max_pool_grad() directly.
    This type of test relies on gen_nn_ops.max_pool_grad() returning the
  correct result. For example:
    * input_tensor_shape = (1, 10, 10, 1)
    * window_size = (1, 2, 2, 1)
    * stride_size = (1, 2, 2, 1)
    * padding: not really important, since 10 is evenly divisible by 2
  max pooling should generate the same result as fractional max pooling with:
    * row_sequence = [0, 2, 4, 6, 8, 10]
    * col_sequence = [0, 2, 4, 6, 8, 10]
    * overlapping = False
  This also means their gradients in such cases will be the same.

    Similarly, when
    * input_tensor_shape = (1, 7, 7, 1)
    * window_size = (1, 3, 3, 1)
    * stride_size = (1, 2, 2, 1)
    * padding: not important
  max pooling should generate the same result as fractional max pooling with:
    * row_sequence = [0, 2, 4, 7]
    * col_sequence = [0, 2, 4, 7]
    * overlapping = True
  2) Test through compute_gradient_error().
  """

  _PRNG = np.random.RandomState(341261)
  _SEED = 123456

  def _GenerateUniqueRandomInputTensor(self, shape):
    """Generate a 'unique' random input tensor.

    'Unique' means there are no colliding values in the tensor; all elements
    are different. This is done by generating a sequence of integers with a
    step of 1 and then randomly shuffling those integers.

    Args:
      shape: Shape of the tensor desired.

    Returns:
      A numpy ndarray with size = shape and dtype = numpy.float32.
    """
    num_elements = 1
    for size in shape:
      num_elements *= size
    x = np.arange(num_elements, dtype=np.float32)
    self._PRNG.shuffle(x)
    return x.reshape(shape)

  def testDirectNotUseOverlapping(self):
    for num_batches in [1, 3]:
      for row_window_size in [2, 5]:
        for col_window_size in [2, 4]:
          num_rows = row_window_size * 5
          num_cols = col_window_size * 7
          for num_channels in [1, 2]:
            input_shape = (num_batches, num_rows, num_cols, num_channels)
            with self.cached_session() as _:
              input_tensor = constant_op.constant(
                  self._GenerateUniqueRandomInputTensor(input_shape))
              window_size = [1, row_window_size, col_window_size, 1]
              stride_size = [1, row_window_size, col_window_size, 1]
              padding = "VALID"
              output_tensor = nn_ops.max_pool(input_tensor, window_size,
                                              stride_size, padding)
              output_data = self.evaluate(output_tensor)
              output_backprop = self._PRNG.randint(100, size=output_data.shape)
              input_backprop_tensor = gen_nn_ops.max_pool_grad(
                  input_tensor, output_tensor, output_backprop, window_size,
                  stride_size, padding)
              input_backprop = self.evaluate(input_backprop_tensor)
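              # Evenly spaced boundaries make fractional max pooling equivalent
              # to the regular max pooling computed above.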
              row_seq = list(range(0, num_rows + 1, row_window_size))
              col_seq = list(range(0, num_cols + 1, col_window_size))
              fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad(
                  input_tensor,
                  output_tensor,
                  output_backprop,
                  row_seq,
                  col_seq,
                  overlapping=False)
              fmp_input_backprop = self.evaluate(fmp_input_backprop_tensor)
              self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
              self.assertAllClose(input_backprop, fmp_input_backprop)

  def testDirectUseOverlapping(self):
    for num_batches in [1, 3]:
      for row_window_size in [2, 5]:
        for col_window_size in [2, 4]:
          num_rows = (row_window_size - 1) * 5 + 1
          num_cols = (col_window_size - 1) * 7 + 1
          for num_channels in [1, 2]:
            input_shape = (num_batches, num_rows, num_cols, num_channels)
            with self.cached_session() as _:
              input_tensor = constant_op.constant(
                  self._GenerateUniqueRandomInputTensor(input_shape))
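              # For overlapping pooling, the window is one larger than the
              # stride along each spatial dimension.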
              window_size = [1, row_window_size, col_window_size, 1]
              stride_size = [1, row_window_size - 1, col_window_size - 1, 1]
              padding = "VALID"
              output_tensor = nn_ops.max_pool(input_tensor, window_size,
                                              stride_size, padding)
              output_data = self.evaluate(output_tensor)
              output_backprop = self._PRNG.randint(100, size=output_data.shape)
              input_backprop_tensor = gen_nn_ops.max_pool_grad(
                  input_tensor, output_tensor, output_backprop, window_size,
                  stride_size, padding)
              input_backprop = self.evaluate(input_backprop_tensor)
              row_seq = list(range(0, num_rows, row_window_size - 1))
              col_seq = list(range(0, num_cols, col_window_size - 1))
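              # Extend the last boundary so the final overlapping region
              # reaches the last row/column (e.g. [0, 2, 4, 7] for 7 rows).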
              row_seq[-1] += 1
              col_seq[-1] += 1
              fmp_input_backprop_tensor = gen_nn_ops.fractional_max_pool_grad(
                  input_tensor,
                  output_tensor,
                  output_backprop,
                  row_seq,
                  col_seq,
                  overlapping=True)
              fmp_input_backprop = self.evaluate(fmp_input_backprop_tensor)
              self.assertShapeEqual(input_backprop, fmp_input_backprop_tensor)
              self.assertAllClose(input_backprop, fmp_input_backprop)

  @test_util.run_deprecated_v1
  def testAllInputOptionsThroughGradientError(self):
    input_shape = (1, 7, 13, 1)
    input_data = self._GenerateUniqueRandomInputTensor(input_shape)
    # Add some randomness to make input_data not so 'integer'
    input_data += self._PRNG.random_sample(input_shape)
    pooling_ratio = [1, math.sqrt(2), math.sqrt(3), 1]

    for pseudo_random in True, False:
      for overlapping in True, False:
        with self.cached_session() as _:
          input_tensor = constant_op.constant(input_data, shape=input_shape)
          output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool_v2(
              input_tensor,
              pooling_ratio,
              pseudo_random=pseudo_random,
              overlapping=overlapping,
              seed=self._SEED)
          output_data = self.evaluate(output_tensor)
          output_shape = output_data.shape
          # error_margin and delta setting is similar to max_pool_grad.
          error_margin = 1e-3
          gradient_error = gradient_checker.compute_gradient_error(
              input_tensor,
              input_shape,
              output_tensor,
              output_shape,
              x_init_value=input_data.reshape(input_shape),
              delta=1e-2)
          self.assertLess(gradient_error, error_margin)

  @test_util.run_deprecated_v1
  def testDifferentTensorShapesThroughGradientError(self):
    pseudo_random = True
    overlapping = True
    pooling_ratio = [1, math.sqrt(3), math.sqrt(2), 1]
    for num_batches in [1, 2]:
      for num_rows in [5, 13]:
        for num_cols in [5, 11]:
          for num_channels in [1, 3]:
            input_shape = (num_batches, num_rows, num_cols, num_channels)
            input_data = self._GenerateUniqueRandomInputTensor(input_shape)
            # Add some randomness to make input_data not so 'integer'
            input_data += self._PRNG.random_sample(input_shape)
            with self.cached_session() as _:
              input_tensor = constant_op.constant(input_data, shape=input_shape)
              output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool_v2(
                  input_tensor,
                  pooling_ratio,
                  pseudo_random=pseudo_random,
                  overlapping=overlapping,
                  seed=self._SEED)
              output_data = self.evaluate(output_tensor)
              output_shape = output_data.shape
              # error_margin and delta setting is similar to max_pool_grad.
              error_margin = 1e-3
              gradient_error = gradient_checker.compute_gradient_error(
                  input_tensor,
                  input_shape,
                  output_tensor,
                  output_shape,
                  x_init_value=input_data.reshape(input_shape),
                  delta=1e-2)
              self.assertLess(gradient_error, error_margin)

  @test_util.run_deprecated_v1
  def testLargePoolingRatioThroughGradientError(self):
    input_shape = (1, 17, 23, 1)
    input_data = self._GenerateUniqueRandomInputTensor(input_shape)
    # Add some randomness to make input_data not so 'integer'
    input_data += self._PRNG.random_sample(input_shape)
    pooling_ratio = (1, math.sqrt(13), math.sqrt(7), 1)
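    # The expected output size along each dimension is
    # floor(input_size / pooling_ratio).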
    output_shape = [int(a / b) for a, b in zip(input_shape, pooling_ratio)]
    overlapping = True
    pseudo_random = False

    with self.cached_session() as _:
      input_tensor = constant_op.constant(input_data, shape=input_shape)
      output_tensor, unused_a, unused_b = nn_ops.fractional_max_pool_v2(
          input_tensor,
          pooling_ratio,
          pseudo_random=pseudo_random,
          overlapping=overlapping,
          seed=self._SEED)
      # error_margin and delta setting is similar to max_pool_grad.
      error_margin = 1e-3
      gradient_error = gradient_checker.compute_gradient_error(
          input_tensor,
          input_shape,
          output_tensor,
          output_shape,
          x_init_value=input_data.reshape(input_shape),
          delta=1e-2)
      self.assertLess(gradient_error, error_margin)

  def testWhenRepeatedMaxValueInPoolingRegion(self):
    """Test when there are repeated values in a pooling region.

    There's no formal definition of what the gradient should be when there are
    multiple max values within a pooling cell, such as
        | 1 5 |
        | 5 3 |
    The expected result depends heavily on the implementation; if someone swaps
    the order of the nested for loops when walking through the tensor, the
    result would be very different.

    The goal of this test is to raise an alert when someone changes the
    implementation. The current implementation scans row-by-row.
    """
    input_data = [5.0, 4.0, 6.0, 7.0,
                  3.0, 5.0, 9.0, 6.0,
                  8.0, 8.0, 9.0, 5.0,
                  7.0, 4.0, 0.0, 0.0]  # pyformat: disable
    input_size = [1, 4, 4, 1]
    output_backprop = [12.0, 15.0,
                       17.0, -5.0,
                       6.0, 21.0]  # pyformat: disable
    row_seq = [0, 1, 3, 4]
    col_seq = [0, 2, 4]
    output_data_not_overlapping = [5.0, 7.0,
                                   8.0, 9.0,
                                   7.0, 0.0]  # pyformat: disable
    output_data_overlapping = [9.0, 9.0,
                               9.0, 9.0,
                               7.0, 0.0]  # pyformat: disable
    output_size = [1, 3, 2, 1]
    expected_input_backprop_not_overlapping = np.reshape(
        [12.0, 0.0, 0.0, 15.0,
         0.0, 0.0, -5.0, 0.0,
         17.0, 0.0, 0.0, 0.0,
         6.0, 0.0, 21.0, 0.0],
        input_size)  # pyformat: disable
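    # With overlapping, the first four pooling regions all select the 9.0 at
    # row 1, col 2, so its gradient accumulates to 12 + 15 + 17 - 5 = 39.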
    expected_input_backprop_overlapping = np.reshape(
        [0.0, 0.0, 0.0, 0.0,
         0.0, 0.0, 39.0, 0.0,
         0.0, 0.0, 0.0, 0.0,
         6.0, 0.0, 21.0, 0.0],
        input_size)  # pyformat: disable
    with self.cached_session() as _:
      # Test when overlapping is False
      input_tensor = constant_op.constant(input_data, shape=input_size)
      output_tensor = constant_op.constant(
          output_data_not_overlapping, shape=output_size)
      grad = constant_op.constant(output_backprop, shape=output_size)
      r = gen_nn_ops.fractional_max_pool_grad(
          input_tensor,
          output_tensor,
          grad,
          row_seq,
          col_seq,
          overlapping=False)
      input_backprop_not_overlapping = self.evaluate(r)
      self.assertShapeEqual(
          np.reshape(expected_input_backprop_not_overlapping, input_size), r)
      self.assertAllClose(expected_input_backprop_not_overlapping,
                          input_backprop_not_overlapping)
      # Test when overlapping is True
      output_tensor = constant_op.constant(
          output_data_overlapping, shape=output_size)
      r = gen_nn_ops.fractional_max_pool_grad(
          input_tensor, output_tensor, grad, row_seq, col_seq, overlapping=True)
      input_backprop_overlapping = self.evaluate(r)
      self.assertShapeEqual(
          np.reshape(expected_input_backprop_overlapping, input_size), r)
      self.assertAllClose(expected_input_backprop_overlapping,
                          input_backprop_overlapping)

  def testInvalidSeqRaiseErrorForFractionalMaxPoolGrad(self):
    with self.assertRaises(errors.InvalidArgumentError):
      with self.cached_session() as _:
        overlapping = True
        orig_input = constant_op.constant(
            .453409232, shape=[1, 7, 13, 1], dtype=dtypes.float32)
        orig_output = constant_op.constant(
            .453409232, shape=[1, 7, 13, 1], dtype=dtypes.float32)
        out_backprop = constant_op.constant(
            .453409232, shape=[1, 7, 13, 1], dtype=dtypes.float32)
        row_pooling_sequence = constant_op.constant(
            0, shape=[5], dtype=dtypes.int64)
        col_pooling_sequence = constant_op.constant(
            0, shape=[5], dtype=dtypes.int64)
        t = gen_nn_ops.FractionalMaxPoolGrad(
            orig_input=orig_input,
            orig_output=orig_output,
            out_backprop=out_backprop,
            row_pooling_sequence=row_pooling_sequence,
            col_pooling_sequence=col_pooling_sequence,
            overlapping=overlapping)
        self.evaluate(t)


if __name__ == "__main__":
  test.main()