1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Tests for sparse_image_warp.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.contrib.image.python.ops import sparse_image_warp 23 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops import clip_ops 28from tensorflow.python.ops import gradients 29from tensorflow.python.ops import image_ops 30from tensorflow.python.ops import io_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.ops import variables 33from tensorflow.python.platform import googletest 34from tensorflow.python.platform import test 35 36from tensorflow.python.training import momentum 37 38 39class SparseImageWarpTest(test_util.TensorFlowTestCase): 40 41 def setUp(self): 42 np.random.seed(0) 43 44 def testGetBoundaryLocations(self): 45 image_height = 11 46 image_width = 11 47 num_points_per_edge = 4 48 locs = sparse_image_warp._get_boundary_locations(image_height, image_width, 49 num_points_per_edge) 50 num_points = locs.shape[0] 51 self.assertEqual(num_points, 4 + 4 * num_points_per_edge) 52 locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)] 53 for i in (0, image_height - 1): 54 for j in (0, image_width - 1): 55 self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) 56 57 for i in (2, 4, 6, 8): 58 for j in (0, image_width - 1): 59 self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) 60 61 for i in (0, image_height - 1): 62 for j in (2, 4, 6, 8): 63 self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j)) 64 65 def testGetGridLocations(self): 66 image_height = 5 67 image_width = 3 68 grid = sparse_image_warp._get_grid_locations(image_height, image_width) 69 for i in range(image_height): 70 for j in range(image_width): 71 self.assertEqual(grid[i, j, 0], i) 72 self.assertEqual(grid[i, j, 1], j) 73 74 def testZeroShift(self): 75 """Run assertZeroShift for various hyperparameters.""" 76 for order in (1, 2): 77 for regularization in (0, 0.01): 78 for num_boundary_points in (0, 1): 79 self.assertZeroShift(order, regularization, num_boundary_points) 80 81 def assertZeroShift(self, order, regularization, num_boundary_points): 82 """Check that warping with zero displacements doesn't change the image.""" 83 batch_size = 1 84 image_height = 4 85 image_width = 4 86 channels = 3 87 88 image = np.random.uniform( 89 size=[batch_size, image_height, image_width, channels]) 90 91 input_image_op = constant_op.constant(np.float32(image)) 92 93 control_point_locations = [[1., 1.], [2., 2.], [2., 1.]] 94 control_point_locations = constant_op.constant( 95 np.float32(np.expand_dims(control_point_locations, 0))) 96 97 control_point_displacements = np.zeros( 98 control_point_locations.shape.as_list()) 99 control_point_displacements = constant_op.constant( 100 np.float32(control_point_displacements)) 101 102 (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( 103 input_image_op, 104 control_point_locations, 105 control_point_locations + control_point_displacements, 106 interpolation_order=order, 107 regularization_weight=regularization, 108 num_boundary_points=num_boundary_points) 109 110 with self.cached_session() as sess: 111 warped_image, input_image, _ = sess.run( 112 [warped_image_op, input_image_op, flow_field]) 113 114 self.assertAllClose(warped_image, input_image) 115 116 def testMoveSinglePixel(self): 117 """Run assertMoveSinglePixel for various hyperparameters and data types.""" 118 for order in (1, 2): 119 for num_boundary_points in (1, 2): 120 for type_to_use in (dtypes.float32, dtypes.float64): 121 self.assertMoveSinglePixel(order, num_boundary_points, type_to_use) 122 123 def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use): 124 """Move a single block in a small grid using warping.""" 125 batch_size = 1 126 image_height = 7 127 image_width = 7 128 channels = 3 129 130 image = np.zeros([batch_size, image_height, image_width, channels]) 131 image[:, 3, 3, :] = 1.0 132 input_image_op = constant_op.constant(image, dtype=type_to_use) 133 134 # Place a control point at the one white pixel. 135 control_point_locations = [[3., 3.]] 136 control_point_locations = constant_op.constant( 137 np.float32(np.expand_dims(control_point_locations, 0)), 138 dtype=type_to_use) 139 # Shift it one pixel to the right. 140 control_point_displacements = [[0., 1.0]] 141 control_point_displacements = constant_op.constant( 142 np.float32(np.expand_dims(control_point_displacements, 0)), 143 dtype=type_to_use) 144 145 (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp( 146 input_image_op, 147 control_point_locations, 148 control_point_locations + control_point_displacements, 149 interpolation_order=order, 150 num_boundary_points=num_boundary_points) 151 152 with self.cached_session() as sess: 153 warped_image, input_image, flow = sess.run( 154 [warped_image_op, input_image_op, flow_field]) 155 # Check that it moved the pixel correctly. 156 self.assertAllClose( 157 warped_image[0, 4, 5, :], 158 input_image[0, 4, 4, :], 159 atol=1e-5, 160 rtol=1e-5) 161 162 # Test that there is no flow at the corners. 163 for i in (0, image_height - 1): 164 for j in (0, image_width - 1): 165 self.assertAllClose( 166 flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5) 167 168 def load_image(self, image_file, sess): 169 image_op = image_ops.decode_png( 170 io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3] 171 return sess.run(image_op) 172 173 def testSmileyFace(self): 174 """Check warping accuracy by comparing to hardcoded warped images.""" 175 176 test_data_dir = test.test_src_dir_path('contrib/image/python/' 177 'kernel_tests/test_data/') 178 input_file = test_data_dir + 'Yellow_Smiley_Face.png' 179 with self.cached_session() as sess: 180 input_image = self.load_image(input_file, sess) 181 control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111], 182 [180 - 39, 111], [90, 143], [58, 134], 183 [180 - 58, 134]]) # pyformat: disable 184 control_point_displacements = np.asarray( 185 [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25], 186 [10, 10.75]]) 187 control_points_op = constant_op.constant( 188 np.expand_dims(np.float32(control_points[:, [1, 0]]), 0)) 189 control_point_displacements_op = constant_op.constant( 190 np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0)) 191 float_image = np.expand_dims(np.float32(input_image) / 255, 0) 192 input_image_op = constant_op.constant(float_image) 193 194 for interpolation_order in (1, 2, 3): 195 for num_boundary_points in (0, 1, 4): 196 warp_op, _ = sparse_image_warp.sparse_image_warp( 197 input_image_op, 198 control_points_op, 199 control_points_op + control_point_displacements_op, 200 interpolation_order=interpolation_order, 201 num_boundary_points=num_boundary_points) 202 with self.cached_session() as sess: 203 warped_image = sess.run(warp_op) 204 out_image = np.uint8(warped_image[0, :, :, :] * 255) 205 target_file = ( 206 test_data_dir + 207 'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format( 208 interpolation_order, num_boundary_points)) 209 210 target_image = self.load_image(target_file, sess) 211 212 # Check that the target_image and out_image difference is no 213 # bigger than 2 (on a scale of 0-255). Due to differences in 214 # floating point computation on different devices, the float 215 # output in warped_image may get rounded to a different int 216 # than that in the saved png file loaded into target_image. 217 self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3) 218 219 def testThatBackpropRuns(self): 220 """Run optimization to ensure that gradients can be computed.""" 221 222 batch_size = 1 223 image_height = 9 224 image_width = 12 225 image = variables.Variable( 226 np.float32( 227 np.random.uniform(size=[batch_size, image_height, image_width, 3]))) 228 control_point_locations = [[3., 3.]] 229 control_point_locations = constant_op.constant( 230 np.float32(np.expand_dims(control_point_locations, 0))) 231 control_point_displacements = [[0.25, -0.5]] 232 control_point_displacements = constant_op.constant( 233 np.float32(np.expand_dims(control_point_displacements, 0))) 234 warped_image, _ = sparse_image_warp.sparse_image_warp( 235 image, 236 control_point_locations, 237 control_point_locations + control_point_displacements, 238 num_boundary_points=3) 239 240 loss = math_ops.reduce_mean(math_ops.abs(warped_image - image)) 241 optimizer = momentum.MomentumOptimizer(0.001, 0.9) 242 grad = gradients.gradients(loss, [image]) 243 grad, _ = clip_ops.clip_by_global_norm(grad, 1.0) 244 opt_func = optimizer.apply_gradients(zip(grad, [image])) 245 init_op = variables.global_variables_initializer() 246 247 with self.cached_session() as sess: 248 sess.run(init_op) 249 for _ in range(5): 250 sess.run([loss, opt_func]) 251 252 253if __name__ == '__main__': 254 googletest.main() 255