• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for sparse_image_warp."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib.image.python.ops import sparse_image_warp
23
24from tensorflow.python.framework import constant_op
25from tensorflow.python.framework import dtypes
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops import clip_ops
28from tensorflow.python.ops import gradients
29from tensorflow.python.ops import image_ops
30from tensorflow.python.ops import io_ops
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import variables
33from tensorflow.python.platform import googletest
34from tensorflow.python.platform import test
35
36from tensorflow.python.training import momentum
37
38
class SparseImageWarpTest(test_util.TensorFlowTestCase):
  """Tests for sparse_image_warp and its private location helpers."""

  def setUp(self):
    # Seed NumPy so that the randomly generated images are reproducible.
    np.random.seed(0)

  def testGetBoundaryLocations(self):
    """Check the number of boundary locations and that expected ones exist."""
    image_height = 11
    image_width = 11
    num_points_per_edge = 4
    locs = sparse_image_warp._get_boundary_locations(image_height, image_width,
                                                     num_points_per_edge)
    num_points = locs.shape[0]
    # Four corners plus `num_points_per_edge` points on each of the four edges.
    self.assertEqual(num_points, 4 + 4 * num_points_per_edge)
    locs = [(locs[i, 0], locs[i, 1]) for i in range(num_points)]

    # The three groups of checks are independent, so they are sibling loops:
    # each one owns its loop variable `i` and every assertion runs once.

    # Corners.
    for i in (0, image_height - 1):
      for j in (0, image_width - 1):
        self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))

    # Expected interior points on the first and last columns.
    for i in (2, 4, 6, 8):
      for j in (0, image_width - 1):
        self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))

    # Expected interior points on the first and last rows.
    for i in (0, image_height - 1):
      for j in (2, 4, 6, 8):
        self.assertIn((i, j), locs, '{},{} not in the locations'.format(i, j))

  def testGetGridLocations(self):
    """Check that grid[i, j] holds the pair (i, j) for every grid position."""
    image_height = 5
    image_width = 3
    grid = sparse_image_warp._get_grid_locations(image_height, image_width)
    for i in range(image_height):
      for j in range(image_width):
        self.assertEqual(grid[i, j, 0], i)
        self.assertEqual(grid[i, j, 1], j)

  def testZeroShift(self):
    """Run assertZeroShift for various hyperparameters."""
    for order in (1, 2):
      for regularization in (0, 0.01):
        for num_boundary_points in (0, 1):
          self.assertZeroShift(order, regularization, num_boundary_points)

  def assertZeroShift(self, order, regularization, num_boundary_points):
    """Check that warping with zero displacements doesn't change the image.

    Args:
      order: Passed to sparse_image_warp as `interpolation_order`.
      regularization: Passed to sparse_image_warp as `regularization_weight`.
      num_boundary_points: Passed to sparse_image_warp as
        `num_boundary_points`.
    """
    batch_size = 1
    image_height = 4
    image_width = 4
    channels = 3

    image = np.random.uniform(
        size=[batch_size, image_height, image_width, channels])

    input_image_op = constant_op.constant(np.float32(image))

    # Three control points, with a leading batch dimension of size one added.
    control_point_locations = [[1., 1.], [2., 2.], [2., 1.]]
    control_point_locations = constant_op.constant(
        np.float32(np.expand_dims(control_point_locations, 0)))

    # All-zero displacements: destination locations equal source locations.
    control_point_displacements = np.zeros(
        control_point_locations.shape.as_list())
    control_point_displacements = constant_op.constant(
        np.float32(control_point_displacements))

    (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp(
        input_image_op,
        control_point_locations,
        control_point_locations + control_point_displacements,
        interpolation_order=order,
        regularization_weight=regularization,
        num_boundary_points=num_boundary_points)

    with self.cached_session() as sess:
      # The flow field is evaluated alongside the images but its value is not
      # checked here.
      warped_image, input_image, _ = sess.run(
          [warped_image_op, input_image_op, flow_field])

      self.assertAllClose(warped_image, input_image)

  def testMoveSinglePixel(self):
    """Run assertMoveSinglePixel for various hyperparameters and data types."""
    for order in (1, 2):
      for num_boundary_points in (1, 2):
        for type_to_use in (dtypes.float32, dtypes.float64):
          self.assertMoveSinglePixel(order, num_boundary_points, type_to_use)

  def assertMoveSinglePixel(self, order, num_boundary_points, type_to_use):
    """Move a single block in a small grid using warping.

    Args:
      order: Passed to sparse_image_warp as `interpolation_order`.
      num_boundary_points: Passed to sparse_image_warp as
        `num_boundary_points`.
      type_to_use: dtype used for the image and control point constants.
    """
    batch_size = 1
    image_height = 7
    image_width = 7
    channels = 3

    # All-black image with a single white pixel at [3, 3].
    image = np.zeros([batch_size, image_height, image_width, channels])
    image[:, 3, 3, :] = 1.0
    input_image_op = constant_op.constant(image, dtype=type_to_use)

    # Place a control point at the one white pixel.
    control_point_locations = [[3., 3.]]
    control_point_locations = constant_op.constant(
        np.float32(np.expand_dims(control_point_locations, 0)),
        dtype=type_to_use)
    # Shift it one pixel to the right.
    control_point_displacements = [[0., 1.0]]
    control_point_displacements = constant_op.constant(
        np.float32(np.expand_dims(control_point_displacements, 0)),
        dtype=type_to_use)

    (warped_image_op, flow_field) = sparse_image_warp.sparse_image_warp(
        input_image_op,
        control_point_locations,
        control_point_locations + control_point_displacements,
        interpolation_order=order,
        num_boundary_points=num_boundary_points)

    with self.cached_session() as sess:
      warped_image, input_image, flow = sess.run(
          [warped_image_op, input_image_op, flow_field])
      # Check that it moved the pixel correctly.
      # NOTE(review): the white pixel is at [3, 3] and is displaced by one in
      # the second coordinate, yet this compares [4, 5] against [4, 4], which
      # are both zero in the input. Confirm whether [3, 4] vs. [3, 3] was
      # intended before tightening this check.
      self.assertAllClose(
          warped_image[0, 4, 5, :],
          input_image[0, 4, 4, :],
          atol=1e-5,
          rtol=1e-5)

      # Test that there is no flow at the corners.
      for i in (0, image_height - 1):
        for j in (0, image_width - 1):
          self.assertAllClose(
              flow[0, i, j, :], np.zeros([2]), atol=1e-5, rtol=1e-5)

  def load_image(self, image_file, sess):
    """Decode a PNG file and return its first three channels.

    Args:
      image_file: Path of the PNG file to read.
      sess: Session used to evaluate the decoding op.

    Returns:
      The result of running the decode op: the image decoded as uint8 with
      four channels, sliced down to channels 0-2.
    """
    image_op = image_ops.decode_png(
        io_ops.read_file(image_file), dtype=dtypes.uint8, channels=4)[:, :, 0:3]
    return sess.run(image_op)

  def testSmileyFace(self):
    """Check warping accuracy by comparing to hardcoded warped images."""

    test_data_dir = test.test_src_dir_path('contrib/image/python/'
                                           'kernel_tests/test_data/')
    input_file = test_data_dir + 'Yellow_Smiley_Face.png'
    with self.cached_session() as sess:
      input_image = self.load_image(input_file, sess)
    control_points = np.asarray([[64, 59], [180 - 64, 59], [39, 111],
                                 [180 - 39, 111], [90, 143], [58, 134],
                                 [180 - 58, 134]])  # pyformat: disable
    control_point_displacements = np.asarray(
        [[-10.5, 10.5], [10.5, 10.5], [0, 0], [0, 0], [0, -10], [-20, 10.25],
         [10, 10.75]])
    # The `[:, [1, 0]]` indexing swaps the two coordinate columns before the
    # batch dimension is added (presumably (x, y) -> (row, col); verify
    # against sparse_image_warp's expected ordering).
    control_points_op = constant_op.constant(
        np.expand_dims(np.float32(control_points[:, [1, 0]]), 0))
    control_point_displacements_op = constant_op.constant(
        np.expand_dims(np.float32(control_point_displacements[:, [1, 0]]), 0))
    # Scale the uint8 image to floats and add a batch dimension.
    float_image = np.expand_dims(np.float32(input_image) / 255, 0)
    input_image_op = constant_op.constant(float_image)

    for interpolation_order in (1, 2, 3):
      for num_boundary_points in (0, 1, 4):
        warp_op, _ = sparse_image_warp.sparse_image_warp(
            input_image_op,
            control_points_op,
            control_points_op + control_point_displacements_op,
            interpolation_order=interpolation_order,
            num_boundary_points=num_boundary_points)
        with self.cached_session() as sess:
          warped_image = sess.run(warp_op)
          out_image = np.uint8(warped_image[0, :, :, :] * 255)
          # One reference image exists per (order, boundary points) pair.
          target_file = (
              test_data_dir +
              'Yellow_Smiley_Face_Warp-interp' + '-{}-clamp-{}.png'.format(
                  interpolation_order, num_boundary_points))

          target_image = self.load_image(target_file, sess)

          # Check that the target_image and out_image difference is no
          # bigger than 2 (on a scale of 0-255). Due to differences in
          # floating point computation on different devices, the float
          # output in warped_image may get rounded to a different int
          # than that in the saved png file loaded into target_image.
          self.assertAllClose(target_image, out_image, atol=2, rtol=1e-3)

  def testThatBackpropRuns(self):
    """Run optimization to ensure that gradients can be computed."""

    batch_size = 1
    image_height = 9
    image_width = 12
    # The image is a Variable so that gradients can be applied to it.
    image = variables.Variable(
        np.float32(
            np.random.uniform(size=[batch_size, image_height, image_width, 3])))
    control_point_locations = [[3., 3.]]
    control_point_locations = constant_op.constant(
        np.float32(np.expand_dims(control_point_locations, 0)))
    control_point_displacements = [[0.25, -0.5]]
    control_point_displacements = constant_op.constant(
        np.float32(np.expand_dims(control_point_displacements, 0)))
    warped_image, _ = sparse_image_warp.sparse_image_warp(
        image,
        control_point_locations,
        control_point_locations + control_point_displacements,
        num_boundary_points=3)

    # Mean absolute difference between the warped image and the original.
    loss = math_ops.reduce_mean(math_ops.abs(warped_image - image))
    optimizer = momentum.MomentumOptimizer(0.001, 0.9)
    grad = gradients.gradients(loss, [image])
    grad, _ = clip_ops.clip_by_global_norm(grad, 1.0)
    opt_func = optimizer.apply_gradients(zip(grad, [image]))
    init_op = variables.global_variables_initializer()

    with self.cached_session() as sess:
      sess.run(init_op)
      # A few steps suffice; the test only checks that the updates execute.
      for _ in range(5):
        sess.run([loss, opt_func])
251
252
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  googletest.main()
255