1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for dense_image_warp.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import math 21import numpy as np 22 23from tensorflow.contrib.image.python.ops import dense_image_warp 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import dtypes 27from tensorflow.python.framework import errors 28from tensorflow.python.framework import test_util 29from tensorflow.python.ops import array_ops 30from tensorflow.python.ops import gradients 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import random_ops 33from tensorflow.python.ops import variables 34from tensorflow.python.platform import googletest 35 36from tensorflow.python.training import adam 37 38 39class DenseImageWarpTest(test_util.TensorFlowTestCase): 40 41 def setUp(self): 42 np.random.seed(0) 43 44 def test_interpolate_small_grid_ij(self): 45 grid = constant_op.constant( 46 [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) 47 query_points = constant_op.constant( 48 [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2]) 49 expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) 50 51 interp = dense_image_warp._interpolate_bilinear(grid, query_points) 52 53 with self.cached_session() as sess: 54 predicted = sess.run(interp) 55 self.assertAllClose(expected_results, predicted) 56 57 def test_interpolate_small_grid_xy(self): 58 grid = constant_op.constant( 59 [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1]) 60 query_points = constant_op.constant( 61 [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2]) 62 expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1]) 63 64 interp = dense_image_warp._interpolate_bilinear( 65 grid, query_points, indexing='xy') 66 67 with self.cached_session() as sess: 68 predicted = sess.run(interp) 69 self.assertAllClose(expected_results, predicted) 70 71 def test_interpolate_small_grid_batched(self): 72 grid = constant_op.constant( 73 [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1]) 74 query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]], 75 [[0.5, 0.], [1., 0.], [1., 1.]]]) 76 expected_results = np.reshape( 77 np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1]) 78 79 interp = dense_image_warp._interpolate_bilinear(grid, query_points) 80 81 with self.cached_session() as sess: 82 predicted = sess.run(interp) 83 self.assertAllClose(expected_results, predicted) 84 85 def get_image_and_flow_placeholders(self, shape, image_type, flow_type): 86 batch_size, height, width, numchannels = shape 87 image_shape = [batch_size, height, width, numchannels] 88 flow_shape = [batch_size, height, width, 2] 89 90 tf_type = { 91 'float16': dtypes.half, 92 'float32': dtypes.float32, 93 'float64': dtypes.float64 94 } 95 96 image = array_ops.placeholder(dtype=tf_type[image_type], shape=image_shape) 97 98 flows = array_ops.placeholder(dtype=tf_type[flow_type], shape=flow_shape) 99 return image, flows 100 101 def get_random_image_and_flows(self, shape, image_type, flow_type): 102 batch_size, height, width, numchannels = shape 103 image_shape = [batch_size, height, width, numchannels] 104 image = np.random.normal(size=image_shape) 105 flow_shape = [batch_size, height, width, 2] 106 flows = np.random.normal(size=flow_shape) * 3 107 return image.astype(image_type), flows.astype(flow_type) 108 109 def assert_correct_interpolation_value(self, 110 image, 111 flows, 112 pred_interpolation, 113 batch_index, 114 y_index, 115 x_index, 116 low_precision=False): 117 """Assert that the tf interpolation matches hand-computed value.""" 118 119 height = image.shape[1] 120 width = image.shape[2] 121 displacement = flows[batch_index, y_index, x_index, :] 122 float_y = y_index - displacement[0] 123 float_x = x_index - displacement[1] 124 floor_y = max(min(height - 2, math.floor(float_y)), 0) 125 floor_x = max(min(width - 2, math.floor(float_x)), 0) 126 ceil_y = floor_y + 1 127 ceil_x = floor_x + 1 128 129 alpha_y = min(max(0.0, float_y - floor_y), 1.0) 130 alpha_x = min(max(0.0, float_x - floor_x), 1.0) 131 132 floor_y = int(floor_y) 133 floor_x = int(floor_x) 134 ceil_y = int(ceil_y) 135 ceil_x = int(ceil_x) 136 137 top_left = image[batch_index, floor_y, floor_x, :] 138 top_right = image[batch_index, floor_y, ceil_x, :] 139 bottom_left = image[batch_index, ceil_y, floor_x, :] 140 bottom_right = image[batch_index, ceil_y, ceil_x, :] 141 142 interp_top = alpha_x * (top_right - top_left) + top_left 143 interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left 144 interp = alpha_y * (interp_bottom - interp_top) + interp_top 145 atol = 1e-6 146 rtol = 1e-6 147 if low_precision: 148 atol = 1e-2 149 rtol = 1e-3 150 self.assertAllClose( 151 interp, 152 pred_interpolation[batch_index, y_index, x_index, :], 153 atol=atol, 154 rtol=rtol) 155 156 def check_zero_flow_correctness(self, shape, image_type, flow_type): 157 """Assert using zero flows doesn't change the input image.""" 158 159 image, flows = self.get_image_and_flow_placeholders(shape, image_type, 160 flow_type) 161 interp = dense_image_warp.dense_image_warp(image, flows) 162 163 with self.cached_session() as sess: 164 rand_image, rand_flows = self.get_random_image_and_flows( 165 shape, image_type, flow_type) 166 rand_flows *= 0 167 168 predicted_interpolation = sess.run( 169 interp, feed_dict={ 170 image: rand_image, 171 flows: rand_flows 172 }) 173 self.assertAllClose(rand_image, predicted_interpolation) 174 175 def test_zero_flows(self): 176 """Apply check_zero_flow_correctness() for a few sizes and types.""" 177 178 shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]] 179 for shape in shapes_to_try: 180 self.check_zero_flow_correctness( 181 shape, image_type='float32', flow_type='float32') 182 183 def check_interpolation_correctness(self, 184 shape, 185 image_type, 186 flow_type, 187 num_probes=5): 188 """Interpolate, and then assert correctness for a few query locations.""" 189 190 image, flows = self.get_image_and_flow_placeholders(shape, image_type, 191 flow_type) 192 interp = dense_image_warp.dense_image_warp(image, flows) 193 low_precision = image_type == 'float16' or flow_type == 'float16' 194 with self.cached_session() as sess: 195 rand_image, rand_flows = self.get_random_image_and_flows( 196 shape, image_type, flow_type) 197 198 pred_interpolation = sess.run( 199 interp, feed_dict={ 200 image: rand_image, 201 flows: rand_flows 202 }) 203 204 for _ in range(num_probes): 205 batch_index = np.random.randint(0, shape[0]) 206 y_index = np.random.randint(0, shape[1]) 207 x_index = np.random.randint(0, shape[2]) 208 209 self.assert_correct_interpolation_value( 210 rand_image, 211 rand_flows, 212 pred_interpolation, 213 batch_index, 214 y_index, 215 x_index, 216 low_precision=low_precision) 217 218 def test_interpolation(self): 219 """Apply check_interpolation_correctness() for a few sizes and types.""" 220 221 shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]] 222 for im_type in ['float32', 'float64', 'float16']: 223 for flow_type in ['float32', 'float64', 'float16']: 224 for shape in shapes_to_try: 225 self.check_interpolation_correctness(shape, im_type, flow_type) 226 227 def test_gradients_exist(self): 228 """Check that backprop can run. 229 230 The correctness of the gradients is assumed, since the forward propagation 231 is tested to be correct and we only use built-in tf ops. 232 However, we perform a simple test to make sure that backprop can actually 233 run. We treat the flows as a tf.Variable and optimize them to minimize 234 the difference between the interpolated image and the input image. 235 """ 236 237 batch_size, height, width, numchannels = [4, 5, 6, 7] 238 image_shape = [batch_size, height, width, numchannels] 239 image = random_ops.random_normal(image_shape) 240 flow_shape = [batch_size, height, width, 2] 241 init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25) 242 flows = variables.Variable(init_flows) 243 244 interp = dense_image_warp.dense_image_warp(image, flows) 245 loss = math_ops.reduce_mean(math_ops.square(interp - image)) 246 247 optimizer = adam.AdamOptimizer(1.0) 248 grad = gradients.gradients(loss, [flows]) 249 opt_func = optimizer.apply_gradients(zip(grad, [flows])) 250 init_op = variables.global_variables_initializer() 251 252 with self.cached_session() as sess: 253 sess.run(init_op) 254 for _ in range(10): 255 sess.run(opt_func) 256 257 def test_size_exception(self): 258 """Make sure it throws an exception for images that are too small.""" 259 260 shape = [1, 2, 1, 1] 261 msg = 'Should have raised an exception for invalid image size' 262 with self.assertRaises(errors.InvalidArgumentError, msg=msg): 263 self.check_interpolation_correctness(shape, 'float32', 'float32') 264 265 266if __name__ == '__main__': 267 googletest.main() 268