1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Functional tests for pooling operations.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import numpy as np 22 23from tensorflow.compiler.tests import xla_test 24from tensorflow.python.framework import dtypes 25from tensorflow.python.framework import ops 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import gen_nn_ops 28from tensorflow.python.ops import nn_ops 29from tensorflow.python.platform import googletest 30 31 32def NHWCToNCHW(input_tensor): 33 """Convert the input from NHWC format to NCHW. 34 35 Args: 36 input_tensor: a 4-D tensor, or a 4-element array representing the same. 37 38 Returns: 39 the converted tensor or a shape array 40 """ 41 if isinstance(input_tensor, ops.Tensor): 42 return array_ops.transpose(input_tensor, [0, 3, 1, 2]) 43 else: 44 return [input_tensor[0], input_tensor[3], input_tensor[1], input_tensor[2]] 45 46 47def NCHWToNHWC(input_tensor): 48 """Convert the input from NCHW format to NHWC. 49 50 Args: 51 input_tensor: a 4-D tensor, or a 4-element array representing the same. 52 53 Returns: 54 the converted tensor or a shape array 55 """ 56 if isinstance(input_tensor, ops.Tensor): 57 return array_ops.transpose(input_tensor, [0, 2, 3, 1]) 58 else: 59 return [input_tensor[0], input_tensor[2], input_tensor[3], input_tensor[1]] 60 61 62def GetTestConfigs(): 63 """Get all the valid tests configs to run. 64 65 Returns: 66 all the valid test configs 67 """ 68 test_configs = ["NHWC", "NCHW"] 69 return test_configs 70 71 72class PoolingTest(xla_test.XLATestCase): 73 74 def _VerifyOneTest(self, pool_func, input_sizes, ksize, strides, padding, 75 data_format, expected): 76 """Verifies the output values of the pooling function. 77 78 Args: 79 pool_func: Function to be called, currently only co.MaxPool. 80 input_sizes: Input tensor dimensions. 81 ksize: The kernel size dimensions 82 strides: The stride dimensions 83 padding: Padding type. 84 data_format: The data format we use to run the pooling operation. 85 expected: An array containing the expected operation outputs. 86 """ 87 total_size = np.prod(input_sizes) 88 # Initializes the input tensor with array containing incrementing 89 # numbers from 1. 90 x = np.array([f * 1.0 for f in range(1, total_size + 1)], dtype=np.float32) 91 x = x.reshape(input_sizes) 92 with self.cached_session() as sess: 93 with self.test_scope(): 94 inputs = array_ops.placeholder(dtypes.float32) 95 t = inputs 96 if data_format == "NCHW": 97 t = NHWCToNCHW(t) 98 ksize = NHWCToNCHW(ksize) 99 strides = NHWCToNCHW(strides) 100 t = pool_func(t, 101 ksize=ksize, 102 strides=strides, 103 padding=padding, 104 data_format=data_format) 105 if data_format == "NCHW": 106 t = NCHWToNHWC(t) 107 actual = sess.run(t, {inputs: x}) 108 self.assertAllClose(expected, actual.flatten(), rtol=1e-5, atol=1e-6) 109 110 def _VerifyValues(self, pool_func, input_sizes, ksize, strides, padding, 111 expected): 112 """Verifies the output values of the pooling function. 113 114 Args: 115 pool_func: Function to be called, co.MaxPool, co.AvgPool, 116 or the Lua version. 117 input_sizes: Input tensor dimensions. 118 ksize: The kernel size dimensions 119 strides: The stride dimensions 120 padding: Padding type. 121 expected: An array containing the expected operation outputs. 122 """ 123 for data_format in GetTestConfigs(): 124 self._VerifyOneTest(pool_func, input_sizes, ksize, strides, padding, 125 data_format, expected) 126 127 def testMaxPoolValidPadding(self): 128 expected_output = [13.0, 14.0, 15.0] 129 self._VerifyValues(nn_ops.max_pool, 130 input_sizes=[1, 3, 3, 3], 131 ksize=[1, 2, 2, 1], 132 strides=[1, 2, 2, 1], 133 padding="VALID", 134 expected=expected_output) 135 136 def testMaxPoolSamePadding(self): 137 expected_output = [13.0, 14.0, 15.0, 16.0, 17.0, 18.0] 138 self._VerifyValues(nn_ops.max_pool, 139 input_sizes=[1, 2, 3, 3], 140 ksize=[1, 2, 2, 1], 141 strides=[1, 2, 2, 1], 142 padding="SAME", 143 expected=expected_output) 144 145 def testMaxPoolSamePaddingNonSquareWindow(self): 146 # input is: 147 # [1.0, 2.0 148 # 3.0 4.0] 149 # 150 # Window of [x, x] should do: 151 # 152 # [max(1.0, 2.0), max(2.0, padded0), 153 # max(3.0, 4.0), max(4.0, padded0)] 154 self._VerifyValues( 155 nn_ops.max_pool, 156 input_sizes=[1, 2, 2, 1], 157 ksize=[1, 1, 2, 1], 158 strides=[1, 1, 1, 1], 159 padding="SAME", 160 expected=[2.0, 2.0, 4.0, 4.0]) 161 162 def testMaxPoolValidPaddingUnevenStride(self): 163 self._VerifyValues( 164 nn_ops.max_pool, 165 input_sizes=[1, 4, 4, 1], 166 ksize=[1, 2, 2, 1], 167 strides=[1, 1, 2, 1], 168 padding="VALID", 169 expected=[6.0, 8.0, 10.0, 12.0, 14.0, 16.0]) 170 self._VerifyValues( 171 nn_ops.max_pool, 172 input_sizes=[1, 4, 4, 1], 173 ksize=[1, 2, 2, 1], 174 strides=[1, 2, 1, 1], 175 padding="VALID", 176 expected=[6.0, 7.0, 8.0, 14.0, 15.0, 16.0]) 177 178 def testMaxPoolSamePaddingFilter4(self): 179 expected_output = [ 180 21.0, 22.0, 23.0, 24.0, 29.0, 30.0, 31.0, 32.0, 53.0, 54.0, 55.0, 56.0, 181 61.0, 62.0, 63.0, 64.0 182 ] 183 self._VerifyValues( 184 nn_ops.max_pool, 185 input_sizes=[1, 4, 4, 4], 186 ksize=[1, 2, 2, 1], 187 strides=[1, 2, 2, 1], 188 padding="SAME", 189 expected=expected_output) 190 191 def testMaxPoolSamePaddingFilter8(self): 192 expected_output = [ 193 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 161.0, 162.0, 194 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 177.0, 178.0, 179.0, 180.0, 195 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 196 191.0, 192.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, 197 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 305.0, 306.0, 198 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0, 199 317.0, 318.0, 319.0, 320.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0, 200 407.0, 408.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, 201 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, 202 443.0, 444.0, 445.0, 446.0, 447.0, 448.0, 465.0, 466.0, 467.0, 468.0, 203 469.0, 470.0, 471.0, 472.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 204 487.0, 488.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, 205 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, 512.0 206 ] 207 self._VerifyValues( 208 nn_ops.max_pool, 209 input_sizes=[1, 8, 8, 8], 210 ksize=[1, 3, 3, 1], 211 strides=[1, 2, 2, 1], 212 padding="SAME", 213 expected=expected_output) 214 215 # Tests for DepthwiseMaxPooling on CPU only. 216 def testDepthwiseMaxPool1x1DepthWindow1(self): 217 # input is: 218 # [1.0, ..., 10.0] along depth, 219 # 220 # We maxpool by depth in patches of 2. 221 self._VerifyValues( 222 nn_ops.max_pool, 223 input_sizes=[1, 1, 1, 10], 224 ksize=[1, 1, 1, 2], 225 strides=[1, 1, 1, 2], 226 padding="SAME", 227 expected=[2.0, 4.0, 6.0, 8.0, 10.0]) 228 229 def testDepthwiseMaxPool2x2DepthWindow3(self): 230 # input is: 231 # 232 # a 2x2x6 cube, and we depthwise max across 3 to produce a 2x2x2 233 # output. Each node has contiguous values, so the depthwise max 234 # should be multiples of 3.0. 235 self._VerifyValues( 236 nn_ops.max_pool, 237 input_sizes=[1, 2, 2, 6], 238 ksize=[1, 1, 1, 3], 239 strides=[1, 1, 1, 3], 240 padding="SAME", 241 expected=[3.0, 6.0, 9.0, 12.0, 15.0, 18.0, 21.0, 24.0]) 242 243 def testKernelSmallerThanStrideValid(self): 244 self._VerifyValues( 245 nn_ops.max_pool, 246 input_sizes=[1, 7, 7, 1], 247 ksize=[1, 2, 2, 1], 248 strides=[1, 3, 3, 1], 249 padding="VALID", 250 expected=[9, 12, 30, 33]) 251 252 def testKernelSmallerThanStrideSame(self): 253 self._VerifyValues( 254 nn_ops.max_pool, 255 input_sizes=[1, 3, 3, 1], 256 ksize=[1, 1, 1, 1], 257 strides=[1, 2, 2, 1], 258 padding="SAME", 259 expected=[1, 3, 7, 9]) 260 261 self._VerifyValues( 262 nn_ops.max_pool, 263 input_sizes=[1, 4, 4, 1], 264 ksize=[1, 1, 1, 1], 265 strides=[1, 2, 2, 1], 266 padding="SAME", 267 expected=[1, 3, 9, 11]) 268 269 # Average pooling 270 def testAvgPoolValidPadding(self): 271 expected_output = [7, 8, 9] 272 self._VerifyValues( 273 nn_ops.avg_pool, 274 input_sizes=[1, 3, 3, 3], 275 ksize=[1, 2, 2, 1], 276 strides=[1, 2, 2, 1], 277 padding="VALID", 278 expected=expected_output) 279 280 def testAvgPoolSamePadding(self): 281 expected_output = [7., 8., 9., 11.5, 12.5, 13.5] 282 self._VerifyValues( 283 nn_ops.avg_pool, 284 input_sizes=[1, 2, 3, 3], 285 ksize=[1, 2, 2, 1], 286 strides=[1, 2, 2, 1], 287 padding="SAME", 288 expected=expected_output) 289 290 291class PoolGradTest(xla_test.XLATestCase): 292 293 CPU_DEVICE = "/job:localhost/replica:0/task:0/cpu:0" 294 295 def _VerifyOneTest(self, 296 pool_func, 297 pool_grad_func, 298 input_sizes, 299 ksize, 300 strides, 301 padding, 302 data_format, 303 pool_grad_grad_func=None): 304 """Verifies the output values of the pooling gradient function. 305 306 Args: 307 pool_func: Forward pooling function 308 pool_grad_func: Pooling gradient function for pool_grad_func 309 input_sizes: Input tensor dimensions. 310 ksize: The kernel size dimensions 311 strides: The stride dimensions 312 padding: Padding type. 313 data_format: The data format we use to run the pooling operation. 314 pool_grad_grad_func: Second-order gradient function, if available. 315 """ 316 total_size = np.prod(input_sizes) 317 # TODO(b/73062247): MaxPoolGradGrad can confuse gradients when x is equally 318 # maximal at 16 bits. Switch to np.random.randn when resolved. 319 x = np.arange(1, total_size + 1, dtype=np.float32) 320 x *= (np.random.randint(2, size=total_size) * 2 - 1) # Flip signs randomly 321 # Verify some specifically interesting values... 322 x[np.random.choice(total_size)] = np.inf 323 x[np.random.choice(total_size)] = -np.inf 324 # TODO(b/74222344): Fix nan handling for max pool grad. 325 # x[np.random.choice(total_size)] = np.nan 326 x = x.reshape(input_sizes) 327 with self.cached_session() as sess: 328 # Use the forward pool function to compute some corresponding outputs 329 # (needed for the CPU device, and we need the shape in both cases). 330 with ops.device(self.CPU_DEVICE): 331 inputs = array_ops.placeholder(dtypes.float32, shape=input_sizes) 332 outputs = pool_func( 333 inputs, 334 ksize=ksize, 335 strides=strides, 336 padding=padding, 337 data_format="NHWC") 338 339 output_vals = np.array(sess.run(outputs, {inputs: x})) 340 output_gradient_vals = np.arange( 341 1, output_vals.size + 1, dtype=np.float32) 342 output_gradient_vals = output_gradient_vals.reshape(output_vals.shape) 343 output_grad_grad_vals = np.arange(1, x.size + 1, dtype=np.float32) 344 output_grad_grad_vals = output_grad_grad_vals.reshape(x.shape) 345 346 # Use the Tensorflow CPU pooling gradient to compute the expected input 347 # gradients. 348 with ops.device(self.CPU_DEVICE): 349 output_gradients = array_ops.placeholder( 350 dtypes.float32, shape=output_vals.shape) 351 expected_input_gradients = pool_grad_func( 352 inputs, 353 outputs, 354 output_gradients, 355 ksize=ksize, 356 strides=strides, 357 padding=padding, 358 data_format="NHWC") 359 expected_input_gradient_vals = sess.run( 360 expected_input_gradients, 361 {inputs: x, 362 output_gradients: output_gradient_vals}) 363 364 output_grad_gradients = array_ops.placeholder( 365 dtypes.float32, shape=expected_input_gradient_vals.shape) 366 if pool_grad_grad_func is not None: 367 expected_grad_gradients = pool_grad_grad_func( 368 inputs, 369 outputs, 370 output_grad_gradients, 371 ksize=ksize, 372 strides=strides, 373 padding=padding, 374 data_format="NHWC") 375 expected_grad_gradients_vals = sess.run(expected_grad_gradients, { 376 inputs: x, 377 output_grad_gradients: output_grad_grad_vals 378 }) 379 380 # Run the gradient op on the XLA device 381 with self.test_scope(): 382 outputs = array_ops.placeholder(dtypes.float32, shape=output_vals.shape) 383 xla_inputs = inputs 384 xla_outputs = outputs 385 xla_output_gradients = output_gradients 386 xla_output_grad_gradients = output_grad_gradients 387 xla_ksize = ksize 388 xla_strides = strides 389 if data_format == "NCHW": 390 xla_inputs = NHWCToNCHW(inputs) 391 xla_outputs = NHWCToNCHW(outputs) 392 xla_output_gradients = NHWCToNCHW(output_gradients) 393 xla_output_grad_gradients = NHWCToNCHW(output_grad_gradients) 394 xla_ksize = NHWCToNCHW(ksize) 395 xla_strides = NHWCToNCHW(strides) 396 actual_input_gradients = pool_grad_func( 397 xla_inputs, 398 xla_outputs, 399 xla_output_gradients, 400 ksize=xla_ksize, 401 strides=xla_strides, 402 padding=padding, 403 data_format=data_format) 404 if data_format == "NCHW": 405 actual_input_gradients = NCHWToNHWC(actual_input_gradients) 406 if pool_grad_grad_func is not None: 407 actual_grad_gradients = pool_grad_grad_func( 408 xla_inputs, 409 xla_outputs, 410 xla_output_grad_gradients, 411 ksize=xla_ksize, 412 strides=xla_strides, 413 padding=padding, 414 data_format=data_format) 415 if data_format == "NCHW": 416 actual_grad_gradients = NCHWToNHWC(actual_grad_gradients) 417 actual_input_gradients_vals = sess.run(actual_input_gradients, { 418 inputs: x, 419 outputs: output_vals, 420 output_gradients: output_gradient_vals 421 }) 422 # Compare the Tensorflow and XLA results. 423 self.assertAllClose( 424 expected_input_gradient_vals, 425 actual_input_gradients_vals, 426 rtol=1e-4, 427 atol=1e-6) 428 self.assertShapeEqual(actual_input_gradients_vals, inputs) 429 430 if pool_grad_grad_func is not None: 431 actual_grad_gradients_vals = sess.run( 432 actual_grad_gradients, { 433 inputs: x, 434 outputs: output_vals, 435 output_grad_gradients: output_grad_grad_vals 436 }) 437 438 # Compare the Tensorflow and XLA results. 439 self.assertAllClose( 440 expected_grad_gradients_vals, 441 actual_grad_gradients_vals, 442 rtol=1e-4, 443 atol=1e-6) 444 self.assertShapeEqual(actual_grad_gradients_vals, outputs) 445 446 def _VerifyValues(self, 447 pool_func, 448 pool_grad_func, 449 input_sizes, 450 ksize, 451 strides, 452 padding, 453 pool_grad_grad_func=None): 454 """Verifies the output values of the pooling function. 455 456 Args: 457 pool_func: Pooling function to be called, e.g., tf.nn.max_pool 458 pool_grad_func: Corresponding pooling gradient function. 459 input_sizes: Input tensor dimensions. 460 ksize: The kernel size dimensions 461 strides: The stride dimensions 462 padding: Padding type. 463 pool_grad_grad_func: Second-order gradient function, if available. 464 """ 465 for data_format in GetTestConfigs(): 466 self._VerifyOneTest( 467 pool_func, 468 pool_grad_func, 469 input_sizes, 470 ksize, 471 strides, 472 padding, 473 data_format, 474 pool_grad_grad_func=pool_grad_grad_func) 475 476 def _TestPooling(self, forward_op, backward_op, pool_grad_grad_func=None): 477 # VALID padding 478 self._VerifyValues( 479 forward_op, 480 backward_op, 481 input_sizes=[1, 3, 3, 3], 482 ksize=[1, 2, 2, 1], 483 strides=[1, 2, 2, 1], 484 padding="VALID", 485 pool_grad_grad_func=pool_grad_grad_func) 486 487 # SAME padding 488 self._VerifyValues( 489 forward_op, 490 backward_op, 491 input_sizes=[1, 2, 3, 3], 492 ksize=[1, 2, 2, 1], 493 strides=[1, 2, 2, 1], 494 padding="SAME", 495 pool_grad_grad_func=pool_grad_grad_func) 496 497 # SAME padding, non square window 498 self._VerifyValues( 499 forward_op, 500 backward_op, 501 input_sizes=[1, 2, 2, 1], 502 ksize=[1, 1, 2, 1], 503 strides=[1, 1, 1, 1], 504 padding="SAME", 505 pool_grad_grad_func=pool_grad_grad_func) 506 507 # VALID padding, uneven stride 508 self._VerifyValues( 509 forward_op, 510 backward_op, 511 input_sizes=[1, 4, 4, 1], 512 ksize=[1, 2, 2, 1], 513 strides=[1, 1, 2, 1], 514 padding="VALID", 515 pool_grad_grad_func=pool_grad_grad_func) 516 self._VerifyValues( 517 forward_op, 518 backward_op, 519 input_sizes=[1, 4, 4, 1], 520 ksize=[1, 2, 2, 1], 521 strides=[1, 2, 1, 1], 522 padding="VALID", 523 pool_grad_grad_func=pool_grad_grad_func) 524 525 # SAME padding, size 4 input 526 self._VerifyValues( 527 forward_op, 528 backward_op, 529 input_sizes=[1, 4, 4, 4], 530 ksize=[1, 2, 2, 1], 531 strides=[1, 2, 2, 1], 532 padding="SAME", 533 pool_grad_grad_func=pool_grad_grad_func) 534 535 # SAME padding, size 8 input 536 self._VerifyValues( 537 forward_op, 538 backward_op, 539 input_sizes=[1, 8, 8, 8], 540 ksize=[1, 3, 3, 1], 541 strides=[1, 2, 2, 1], 542 padding="SAME", 543 pool_grad_grad_func=pool_grad_grad_func) 544 545 def testMaxPool(self): 546 self._TestPooling( 547 nn_ops.max_pool, 548 gen_nn_ops.max_pool_grad, 549 pool_grad_grad_func=gen_nn_ops.max_pool_grad_grad) 550 551 def testAvgPool(self): 552 # Wrapper around AvgPoolGrad that ignores extra arguments needed by 553 # MaxPoolGrad. 554 def AvgPoolGrad(inputs, outputs, output_gradients, ksize, strides, padding, 555 data_format): 556 del outputs # Unused by average-pooling gradients. 557 return gen_nn_ops.avg_pool_grad( 558 inputs.get_shape().as_list(), 559 output_gradients, 560 ksize=ksize, 561 strides=strides, 562 padding=padding, 563 data_format=data_format) 564 565 self._TestPooling(nn_ops.avg_pool, AvgPoolGrad) 566 567 # The CPU implementation of AvgPoolGrad doesn't accept kernels smaller than 568 # the stride size, so we only run the following tests on MaxPoolGrad. 569 570 def testMaxPoolKernelSmallerThanStrideValid(self): 571 self._VerifyValues( 572 nn_ops.max_pool, 573 gen_nn_ops.max_pool_grad, 574 input_sizes=[1, 7, 7, 1], 575 ksize=[1, 2, 2, 1], 576 strides=[1, 3, 3, 1], 577 padding="VALID") 578 579 def testMaxPoolKernelSmallerThanStrideSame(self): 580 self._VerifyValues( 581 nn_ops.max_pool, 582 gen_nn_ops.max_pool_grad, 583 input_sizes=[1, 3, 3, 1], 584 ksize=[1, 1, 1, 1], 585 strides=[1, 2, 2, 1], 586 padding="SAME") 587 588 self._VerifyValues( 589 nn_ops.max_pool, 590 gen_nn_ops.max_pool_grad, 591 input_sizes=[1, 4, 4, 1], 592 ksize=[1, 1, 1, 1], 593 strides=[1, 2, 2, 1], 594 padding="SAME") 595 596 597if __name__ == "__main__": 598 googletest.main() 599