• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for dense_image_warp."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import math
21import numpy as np
22
23from tensorflow.contrib.image.python.ops import dense_image_warp
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import dtypes
27from tensorflow.python.framework import errors
28from tensorflow.python.framework import test_util
29from tensorflow.python.ops import array_ops
30from tensorflow.python.ops import gradients
31from tensorflow.python.ops import math_ops
32from tensorflow.python.ops import random_ops
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import googletest
35
36from tensorflow.python.training import adam
37
38
class DenseImageWarpTest(test_util.TensorFlowTestCase):
  """Tests for `dense_image_warp` and its `_interpolate_bilinear` helper."""

  def setUp(self):
    """Run the base-class per-test setup, then seed numpy's RNG with 0."""
    # The previous override never called TensorFlowTestCase.setUp, so the
    # framework's per-test initialization (e.g. resetting the default graph
    # and seeds) was skipped. Chain to it first; re-seeding numpy afterwards
    # keeps every random input below identical to the old behaviour.
    super(DenseImageWarpTest, self).setUp()
    np.random.seed(0)

  def test_interpolate_small_grid_ij(self):
    """Query points given as (row, column) pairs, no `indexing` argument."""
    grid = constant_op.constant(
        [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1])
    query_points = constant_op.constant(
        [[0., 0.], [1., 0.], [2., 0.5], [1.5, 1.5]], shape=[1, 4, 2])
    expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])

    interp = dense_image_warp._interpolate_bilinear(grid, query_points)

    with self.cached_session() as sess:
      predicted = sess.run(interp)
      self.assertAllClose(expected_results, predicted)

  def test_interpolate_small_grid_xy(self):
    """Same samples as the ij test, given as (x, y) with indexing='xy'."""
    grid = constant_op.constant(
        [[0., 1., 2.], [3., 4., 5.], [6., 7., 8.]], shape=[1, 3, 3, 1])
    query_points = constant_op.constant(
        [[0., 0.], [0., 1.], [0.5, 2.0], [1.5, 1.5]], shape=[1, 4, 2])
    expected_results = np.reshape(np.array([0., 3., 6.5, 6.]), [1, 4, 1])

    interp = dense_image_warp._interpolate_bilinear(
        grid, query_points, indexing='xy')

    with self.cached_session() as sess:
      predicted = sess.run(interp)
      self.assertAllClose(expected_results, predicted)

  def test_interpolate_small_grid_batched(self):
    """Each batch element's queries are sampled from its own grid."""
    grid = constant_op.constant(
        [[[0., 1.], [3., 4.]], [[5., 6.], [7., 8.]]], shape=[2, 2, 2, 1])
    query_points = constant_op.constant([[[0., 0.], [1., 0.], [0.5, 0.5]],
                                         [[0.5, 0.], [1., 0.], [1., 1.]]])
    expected_results = np.reshape(
        np.array([[0., 3., 2.], [6., 7., 8.]]), [2, 3, 1])

    interp = dense_image_warp._interpolate_bilinear(grid, query_points)

    with self.cached_session() as sess:
      predicted = sess.run(interp)
      self.assertAllClose(expected_results, predicted)

  @staticmethod
  def _image_and_flow_shapes(shape):
    """Return `(image_shape, flow_shape)` for a `[b, h, w, c]` image shape."""
    batch_size, height, width, num_channels = shape
    return ([batch_size, height, width, num_channels],
            [batch_size, height, width, 2])

  def get_image_and_flow_placeholders(self, shape, image_type, flow_type):
    """Create placeholders for an image batch and its per-pixel flows.

    Args:
      shape: `[batch_size, height, width, num_channels]` of the image.
      image_type: dtype name of the image, e.g. 'float16' or 'float32'.
      flow_type: dtype name of the flows.

    Returns:
      `(image, flows)` placeholders. The flows have shape
      `[batch_size, height, width, 2]`.
    """
    image_shape, flow_shape = self._image_and_flow_shapes(shape)
    image = array_ops.placeholder(
        dtype=dtypes.as_dtype(image_type), shape=image_shape)
    flows = array_ops.placeholder(
        dtype=dtypes.as_dtype(flow_type), shape=flow_shape)
    return image, flows

  def get_random_image_and_flows(self, shape, image_type, flow_type):
    """Draw a random image batch and random flows as numpy arrays.

    Args:
      shape: `[batch_size, height, width, num_channels]` of the image.
      image_type: numpy dtype name to cast the image to.
      flow_type: numpy dtype name to cast the flows to.

    Returns:
      `(image, flows)`. The image is standard-normal. The flows are normal
      with standard deviation 3 and shape `[batch_size, height, width, 2]`.
    """
    image_shape, flow_shape = self._image_and_flow_shapes(shape)
    image = np.random.normal(size=image_shape)
    # Scale the flows so that displacements span several pixels.
    flows = np.random.normal(size=flow_shape) * 3
    return image.astype(image_type), flows.astype(flow_type)

  def assert_correct_interpolation_value(self,
                                         image,
                                         flows,
                                         pred_interpolation,
                                         batch_index,
                                         y_index,
                                         x_index,
                                         low_precision=False):
    """Assert that one warped pixel matches a hand-computed bilinear sample.

    The reference samples `image` at `(y_index - flow_y, x_index - flow_x)`.
    The sampling cell is clamped to lie inside the image, and the fractional
    weights are clamped to `[0, 1]`.

    Args:
      image: numpy image batch, `[batch, height, width, channels]`.
      flows: numpy flows, `[batch, height, width, 2]` as (dy, dx).
      pred_interpolation: numpy output of `dense_image_warp`.
      batch_index: batch element to check.
      y_index: row of the output pixel to check.
      x_index: column of the output pixel to check.
      low_precision: use looser tolerances (for float16 inputs).
    """
    height, width = image.shape[1], image.shape[2]
    flow_y, flow_x = flows[batch_index, y_index, x_index, :]
    float_y = y_index - flow_y
    float_x = x_index - flow_x

    # Clamp the top-left corner so that it and its +1 neighbours are valid.
    floor_y = max(min(height - 2, math.floor(float_y)), 0)
    floor_x = max(min(width - 2, math.floor(float_x)), 0)
    # Clamped weights make out-of-range samples take the edge row/column.
    alpha_y = min(max(0.0, float_y - floor_y), 1.0)
    alpha_x = min(max(0.0, float_x - floor_x), 1.0)

    floor_y, floor_x = int(floor_y), int(floor_x)
    ceil_y, ceil_x = floor_y + 1, floor_x + 1

    pixels = image[batch_index]
    top_left = pixels[floor_y, floor_x, :]
    top_right = pixels[floor_y, ceil_x, :]
    bottom_left = pixels[ceil_y, floor_x, :]
    bottom_right = pixels[ceil_y, ceil_x, :]

    interp_top = alpha_x * (top_right - top_left) + top_left
    interp_bottom = alpha_x * (bottom_right - bottom_left) + bottom_left
    expected = alpha_y * (interp_bottom - interp_top) + interp_top

    atol, rtol = (1e-2, 1e-3) if low_precision else (1e-6, 1e-6)
    self.assertAllClose(
        expected,
        pred_interpolation[batch_index, y_index, x_index, :],
        atol=atol,
        rtol=rtol)

  def _warp_random_batch(self, shape, image_type, flow_type,
                         zero_flows=False):
    """Warp a freshly drawn random image batch through `dense_image_warp`.

    Args:
      shape: `[batch_size, height, width, num_channels]` of the image.
      image_type: dtype name of the image.
      flow_type: dtype name of the flows.
      zero_flows: if True, feed all-zero flows instead of the random ones.

    Returns:
      `(image, flows, warped)` numpy arrays: the fed image, the fed flows and
      the warped result.
    """
    image, flows = self.get_image_and_flow_placeholders(shape, image_type,
                                                        flow_type)
    interp = dense_image_warp.dense_image_warp(image, flows)
    with self.cached_session() as sess:
      rand_image, rand_flows = self.get_random_image_and_flows(
          shape, image_type, flow_type)
      if zero_flows:
        # The random flows are still drawn above, so the numpy RNG stream
        # does not depend on this option.
        rand_flows = np.zeros_like(rand_flows)
      warped = sess.run(
          interp, feed_dict={
              image: rand_image,
              flows: rand_flows
          })
    return rand_image, rand_flows, warped

  def check_zero_flow_correctness(self, shape, image_type, flow_type):
    """Assert using zero flows doesn't change the input image."""
    rand_image, _, warped = self._warp_random_batch(
        shape, image_type, flow_type, zero_flows=True)
    self.assertAllClose(rand_image, warped)

  def test_zero_flows(self):
    """Apply check_zero_flow_correctness() for a few sizes and types."""
    shapes_to_try = [[3, 4, 5, 6], [1, 2, 2, 1]]
    for shape in shapes_to_try:
      self.check_zero_flow_correctness(
          shape, image_type='float32', flow_type='float32')

  def check_interpolation_correctness(self,
                                      shape,
                                      image_type,
                                      flow_type,
                                      num_probes=5):
    """Interpolate, then check `num_probes` random pixels against a reference."""
    low_precision = 'float16' in (image_type, flow_type)
    rand_image, rand_flows, warped = self._warp_random_batch(
        shape, image_type, flow_type)

    batch_size, height, width = shape[:3]
    for _ in range(num_probes):
      batch_index = np.random.randint(0, batch_size)
      y_index = np.random.randint(0, height)
      x_index = np.random.randint(0, width)
      self.assert_correct_interpolation_value(
          rand_image,
          rand_flows,
          warped,
          batch_index,
          y_index,
          x_index,
          low_precision=low_precision)

  def test_interpolation(self):
    """Apply check_interpolation_correctness() for a few sizes and types."""
    shapes_to_try = [[3, 4, 5, 6], [1, 5, 5, 3], [1, 2, 2, 1]]
    for im_type in ['float32', 'float64', 'float16']:
      for flow_type in ['float32', 'float64', 'float16']:
        for shape in shapes_to_try:
          self.check_interpolation_correctness(shape, im_type, flow_type)

  def test_gradients_exist(self):
    """Check that backprop can run.

    The correctness of the gradients is assumed, since the forward propagation
    is tested to be correct and we only use built-in tf ops.
    However, we perform a simple test to make sure that backprop can actually
    run. We treat the flows as a tf.Variable and optimize them to minimize
    the difference between the interpolated image and the input image.
    """
    batch_size, height, width, num_channels = [4, 5, 6, 7]
    image_shape = [batch_size, height, width, num_channels]
    image = random_ops.random_normal(image_shape)
    flow_shape = [batch_size, height, width, 2]
    init_flows = np.float32(np.random.normal(size=flow_shape) * 0.25)
    flows = variables.Variable(init_flows)

    interp = dense_image_warp.dense_image_warp(image, flows)
    loss = math_ops.reduce_mean(math_ops.square(interp - image))

    optimizer = adam.AdamOptimizer(1.0)
    grad = gradients.gradients(loss, [flows])
    opt_func = optimizer.apply_gradients(list(zip(grad, [flows])))
    init_op = variables.global_variables_initializer()

    with self.cached_session() as sess:
      sess.run(init_op)
      for _ in range(10):
        sess.run(opt_func)

  def test_size_exception(self):
    """Make sure it throws an exception for images that are too small."""
    # A width of 1 leaves no room for a 2x2 bilinear sampling cell.
    shape = [1, 2, 1, 1]
    msg = 'Should have raised an exception for invalid image size'
    with self.assertRaises(errors.InvalidArgumentError, msg=msg):
      self.check_interpolation_correctness(shape, 'float32', 'float32')
264
265
if __name__ == '__main__':
  # Run every test case defined in this module via TensorFlow's googletest.
  googletest.main()
268