# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for convolution related functionality in tensorflow.ops.nn."""

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
from tensorflow.python.platform import test


class Conv3DTransposeTest(test.TestCase):
  """Tests for nn_ops.conv3d_transpose (forward values, shapes, gradients)."""

  def testConv3DTransposeSingleStride(self):
    """Checks output values for stride 1 with SAME padding.

    With an all-ones input and all-ones filter, each output cell equals
    in_channels times the number of kernel taps that land inside the input,
    which depends only on how close the cell is to the volume boundary.
    """
    with self.cached_session():
      strides = [1, 1, 1, 1, 1]

      # Input, output: [batch, depth, height, width, channel]
      x_shape = [2, 5, 6, 4, 3]
      y_shape = [2, 5, 6, 4, 2]

      # Filter: [kernel_depth, kernel_height, kernel_width, out_depth, in_depth]
      f_shape = [3, 3, 3, 2, 3]

      x = constant_op.constant(
          1.0, shape=x_shape, name="x", dtype=dtypes.float32)
      f = constant_op.constant(
          1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
      output = nn_ops.conv3d_transpose(
          x, f, y_shape, strides=strides, padding="SAME")
      value = self.evaluate(output)

      # We count the number of cells being added at the locations in the
      # output.
      # At the center, #cells = kernel_depth * kernel_height * kernel_width
      # At the corners, #cells = ceil(kernel_depth/2) * ceil(kernel_height/2)
      #                          * ceil(kernel_width/2)
      # At the edges, #cells =
      #   kernel_depth * ceil(kernel_height/2) * ceil(kernel_width/2) or
      #   ceil(kernel_depth/2) * kernel_height * ceil(kernel_width/2) or
      #   ceil(kernel_depth/2) * ceil(kernel_height/2) * kernel_width
      # At the borders, #cells =
      #   ceil(kernel_depth/2) * kernel_height * kernel_width or
      #   kernel_depth * ceil(kernel_height/2) * kernel_width or
      #   kernel_depth * kernel_height * ceil(kernel_width/2)

      for n in range(x_shape[0]):
        for k in range(f_shape[3]):
          for w in range(y_shape[3]):
            for h in range(y_shape[2]):
              for d in range(y_shape[1]):
                d_in = d > 0 and d < y_shape[1] - 1
                h_in = h > 0 and h < y_shape[2] - 1
                w_in = w > 0 and w < y_shape[3] - 1
                if d_in + h_in + w_in == 3:
                  target = 27 * 3.0
                elif d_in + h_in + w_in == 2:
                  target = 18 * 3.0
                elif d_in or h_in or w_in:
                  target = 12 * 3.0
                else:
                  target = 8 * 3.0
                self.assertAllClose(target, value[n, d, h, w, k])

  def testConv3DTransposeSame(self):
    """Checks output values for stride 2 with SAME padding."""
    with self.cached_session():
      strides = [1, 2, 2, 2, 1]

      # Input, output: [batch, depth, height, width, channel]
      x_shape = [2, 5, 6, 4, 3]
      y_shape = [2, 10, 12, 8, 2]

      # Filter: [kernel_depth, kernel_height, kernel_width, out_depth, in_depth]
      f_shape = [3, 3, 3, 2, 3]

      x = constant_op.constant(
          1.0, shape=x_shape, name="x", dtype=dtypes.float32)
      f = constant_op.constant(
          1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
      output = nn_ops.conv3d_transpose(
          x, f, y_shape, strides=strides, padding="SAME")
      value = self.evaluate(output)

      for n in range(x_shape[0]):
        for k in range(f_shape[3]):
          for w in range(y_shape[3]):
            for h in range(y_shape[2]):
              for d in range(y_shape[1]):
                # We add a case for locations divisible by the stride.
                d_in = d % strides[1] == 0 and 0 < d < y_shape[1] - 1
                h_in = h % strides[2] == 0 and 0 < h < y_shape[2] - 1
                w_in = w % strides[3] == 0 and 0 < w < y_shape[3] - 1
                if d_in + h_in + w_in == 3:
                  target = 8 * 3.0
                elif d_in + h_in + w_in == 2:
                  target = 4 * 3.0
                elif d_in or h_in or w_in:
                  target = 2 * 3.0
                else:
                  target = 3.0
                self.assertAllClose(target, value[n, d, h, w, k])

  @test_util.run_deprecated_v1
  def testConv3DTransposeShapeMismatch(self):
    """Op construction with NCDHW data format must not raise.

    Test case for GitHub issue 18460.
    """
    x_shape = [2, 2, 3, 4, 3]
    f_shape = [3, 3, 3, 2, 2]
    y_shape = [2, 2, 6, 8, 6]
    strides = [1, 1, 2, 2, 2]
    np.random.seed(1)
    x_value = np.random.random_sample(x_shape).astype(np.float64)
    f_value = np.random.random_sample(f_shape).astype(np.float64)
    nn_ops.conv3d_transpose(
        x_value, f_value, y_shape, strides, data_format='NCDHW')

  def testConv3DTransposeOutputShapeType(self):
    """output_shape may be an int32 or int64 tensor.

    Test case for GitHub issue 18887.
    """
    for dtype in [dtypes.int32, dtypes.int64]:
      with self.cached_session():
        x_shape = [2, 5, 6, 4, 3]
        y_shape = [2, 5, 6, 4, 2]
        f_shape = [3, 3, 3, 2, 3]
        strides = [1, 1, 1, 1, 1]
        x_value = constant_op.constant(
            1.0, shape=x_shape, name="x", dtype=dtypes.float32)
        f_value = constant_op.constant(
            1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
        output = nn_ops.conv3d_transpose(
            x_value, f_value, constant_op.constant(y_shape, dtype=dtype),
            strides=strides, padding="SAME")
        self.evaluate(output)

  def testConv3DTransposeValid(self):
    """Checks output values for stride 2 with VALID padding."""
    with self.cached_session():
      strides = [1, 2, 2, 2, 1]

      # Input, output: [batch, depth, height, width, channel]
      x_shape = [2, 5, 6, 4, 3]
      y_shape = [2, 11, 13, 9, 2]

      # Filter: [kernel_depth, kernel_height, kernel_width, out_depth, in_depth]
      f_shape = [3, 3, 3, 2, 3]

      x = constant_op.constant(
          1.0, shape=x_shape, name="x", dtype=dtypes.float32)
      f = constant_op.constant(
          1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
      output = nn_ops.conv3d_transpose(
          x, f, y_shape, strides=strides, padding="VALID")
      value = self.evaluate(output)

      cache_values = np.zeros(y_shape, dtype=np.float32)

      # The amount of padding added
      pad = 1

      for n in range(x_shape[0]):
        for k in range(f_shape[3]):
          for w in range(y_shape[3]):
            for h in range(y_shape[2]):
              for d in range(y_shape[1]):
                # We add a case for locations divisible by the stride.
                d_in = d % strides[1] == 0 and pad < d < y_shape[1] - 1 - pad
                h_in = h % strides[2] == 0 and pad < h < y_shape[2] - 1 - pad
                w_in = w % strides[3] == 0 and pad < w < y_shape[3] - 1 - pad
                if d_in + h_in + w_in == 3:
                  target = 8 * 3.0
                elif d_in + h_in + w_in == 2:
                  target = 4 * 3.0
                elif d_in or h_in or w_in:
                  target = 2 * 3.0
                else:
                  target = 3.0
                cache_values[n, d, h, w, k] = target

          # copy values in the border
          cache_values[n, :, :, 0, k] = cache_values[n, :, :, 1, k]
          cache_values[n, :, :, -1, k] = cache_values[n, :, :, -2, k]
          cache_values[n, :, 0, :, k] = cache_values[n, :, 1, :, k]
          cache_values[n, :, -1, :, k] = cache_values[n, :, -2, :, k]
          cache_values[n, 0, :, :, k] = cache_values[n, 1, :, :, k]
          cache_values[n, -1, :, :, k] = cache_values[n, -2, :, :, k]

      self.assertAllClose(cache_values, value)

  @test_util.run_deprecated_v1
  def testGradient(self):
    """Numerically checks the gradient of conv3d_transpose."""
    x_shape = [2, 3, 4, 3, 2]
    f_shape = [3, 3, 3, 2, 2]
    y_shape = [2, 6, 8, 6, 2]
    strides = [1, 2, 2, 2, 1]
    np.random.seed(1)  # Make it reproducible.
    x_val = np.random.random_sample(x_shape).astype(np.float64)
    f_val = np.random.random_sample(f_shape).astype(np.float64)
    with self.cached_session():
      x = constant_op.constant(x_val, name="x", dtype=dtypes.float32)
      f = constant_op.constant(f_val, name="f", dtype=dtypes.float32)
      output = nn_ops.conv3d_transpose(
          x, f, y_shape, strides=strides, padding="SAME")
      err = gradient_checker.compute_gradient_error([x, f], [x_shape, f_shape],
                                                    output, y_shape)
    print("conv3d_transpose gradient err = %g " % err)
    err_tolerance = 0.00055
    self.assertLess(err, err_tolerance)


if __name__ == "__main__":
  test.main()