# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image warping using per-pixel flow vectors."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops

from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops


def _interpolate_bilinear(grid,
                          query_points,
                          name='interpolate_bilinear',
                          indexing='ij'):
  """Similar to Matlab's interp2 function.

  Finds values for query points on a grid using bilinear interpolation.

  Args:
    grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`.
    query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`.
    name: a name for the operation (optional).
    indexing: whether the query points are specified as row and column (ij),
      or Cartesian coordinates (xy).

  Returns:
    values: a 3-D `Tensor` with shape `[batch, N, channels]`

  Raises:
    ValueError: if the indexing mode is invalid, or if the shape of the inputs
      invalid.
  """
  # Only 'ij' (row, col) and 'xy' (Cartesian) query orderings are supported.
  if indexing != 'ij' and indexing != 'xy':
    raise ValueError('Indexing mode must be \'ij\' or \'xy\'')

  with ops.name_scope(name):
    grid = ops.convert_to_tensor(grid)
    query_points = ops.convert_to_tensor(query_points)
    # The rank must be known statically so it can be validated at graph
    # construction time; the individual dimensions may still be dynamic.
    shape = grid.get_shape().as_list()
    if len(shape) != 4:
      msg = 'Grid must be 4 dimensional. Received size: '
      raise ValueError(msg + str(grid.get_shape()))

    # Dynamic (graph-time) dimension sizes, each a scalar int32 tensor.
    batch_size, height, width, channels = (array_ops.shape(grid)[0],
                                           array_ops.shape(grid)[1],
                                           array_ops.shape(grid)[2],
                                           array_ops.shape(grid)[3])

    # Rebind `shape` to the dynamic dims so indexing below (shape[dim + 1])
    # works even when the static shape is partially unknown.
    shape = [batch_size, height, width, channels]
    query_type = query_points.dtype
    grid_type = grid.dtype

    # Validate query_points before reading its size: the assertions become
    # control dependencies of `num_queries`, so they execute first.
    with ops.control_dependencies([
        check_ops.assert_equal(
            len(query_points.get_shape()),
            3,
            message='Query points must be 3 dimensional.'),
        check_ops.assert_equal(
            array_ops.shape(query_points)[2],
            2,
            message='Query points must be size 2 in dim 2.')
    ]):
      num_queries = array_ops.shape(query_points)[1]

    # Bilinear interpolation needs a floor and a floor + 1 pixel in each
    # spatial dimension, hence the minimum-size-2 requirement.
    with ops.control_dependencies([
        check_ops.assert_greater_equal(
            height, 2, message='Grid height must be at least 2.'),
        check_ops.assert_greater_equal(
            width, 2, message='Grid width must be at least 2.')
    ]):
      alphas = []
      floors = []
      ceils = []
      # For 'xy' indexing the x (column) coordinate comes first in
      # query_points, so process the dims in reversed order; the results
      # (floors/ceils/alphas) are always appended in [row, col] order below.
      index_order = [0, 1] if indexing == 'ij' else [1, 0]
      unstacked_query_points = array_ops.unstack(query_points, axis=2)

    for dim in index_order:
      with ops.name_scope('dim-' + str(dim)):
        queries = unstacked_query_points[dim]

        size_in_indexing_dimension = shape[dim + 1]

        # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1
        # is still a valid index into the grid.
        max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type)
        min_floor = constant_op.constant(0.0, dtype=query_type)
        # Clamping the floor to [0, size - 2] implements nearest-boundary
        # behavior for out-of-range query points.
        floor = math_ops.minimum(
            math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor)
        int_floor = math_ops.cast(floor, dtypes.int32)
        floors.append(int_floor)
        ceil = int_floor + 1
        ceils.append(ceil)

        # alpha has the same type as the grid, as we will directly use alpha
        # when taking linear combinations of pixel values from the image.
        alpha = math_ops.cast(queries - floor, grid_type)
        min_alpha = constant_op.constant(0.0, dtype=grid_type)
        max_alpha = constant_op.constant(1.0, dtype=grid_type)
        # Clamp the interpolation weight to [0, 1]; values outside this range
        # can only arise from clamped floors (out-of-range queries).
        alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha)

        # Expand alpha to [b, n, 1] so we can use broadcasting
        # (since the alpha values don't depend on the channel).
        alpha = array_ops.expand_dims(alpha, 2)
        alphas.append(alpha)

    # The gather below linearizes (batch, y, x) into a single int32 index,
    # so guard against overflow of the linearized address space.
    with ops.control_dependencies([
        check_ops.assert_less_equal(
            math_ops.cast(batch_size * height * width, dtype=dtypes.float32),
            np.iinfo(np.int32).max / 8,
            message="""The image size or batch size is sufficiently large
                       that the linearized addresses used by array_ops.gather
                       may exceed the int32 limit.""")
    ]):
      flattened_grid = array_ops.reshape(
          grid, [batch_size * height * width, channels])
      batch_offsets = array_ops.reshape(
          math_ops.range(batch_size) * height * width, [batch_size, 1])

    # This wraps array_ops.gather. We reshape the image data such that the
    # batch, y, and x coordinates are pulled into the first dimension.
    # Then we gather. Finally, we reshape the output back. It's possible this
    # code would be made simpler by using array_ops.gather_nd.
    def gather(y_coords, x_coords, name):
      with ops.name_scope('gather-' + name):
        linear_coordinates = batch_offsets + y_coords * width + x_coords
        gathered_values = array_ops.gather(flattened_grid, linear_coordinates)
        return array_ops.reshape(gathered_values,
                                 [batch_size, num_queries, channels])

    # grab the pixel values in the 4 corners around each query point
    top_left = gather(floors[0], floors[1], 'top_left')
    top_right = gather(floors[0], ceils[1], 'top_right')
    bottom_left = gather(ceils[0], floors[1], 'bottom_left')
    bottom_right = gather(ceils[0], ceils[1], 'bottom_right')

    # now, do the actual interpolation: lerp along the column axis on the
    # top and bottom rows, then lerp those results along the row axis.
    with ops.name_scope('interpolate'):
      interp_top = alphas[1] * (top_right - top_left) + top_left
      interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left
      interp = alphas[0] * (interp_bottom - interp_top) + interp_top

    return interp


def dense_image_warp(image, flow, name='dense_image_warp'):
  """Image warping using per-pixel flow vectors.

  Warps `image` with the dense flow field `flow`: the output pixel at
  [b, j, i, c] is obtained by sampling the input image at
  (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c). Since these sample
  locations are generally not integer indices, the value is computed by
  bilinear interpolation of the 4 nearest pixels; locations outside the
  image use the nearest pixel values at the image boundary.

  Args:
    image: 4-D float `Tensor` with shape `[batch, height, width, channels]`.
    flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`.
    name: A name for the operation (optional).

    Note that image and flow can be of type tf.half, tf.float32, or tf.float64,
    and do not necessarily have to be the same type.

  Returns:
    A 4-D float `Tensor` with shape`[batch, height, width, channels]`
      and same type as input image.

  Raises:
    ValueError: if height < 2 or width < 2 or the inputs have the wrong number
                of dimensions.
  """
  with ops.name_scope(name):
    image_shape = array_ops.shape(image)
    batch_size = image_shape[0]
    height = image_shape[1]
    width = image_shape[2]
    channels = image_shape[3]

    # Build the regular (row, col) coordinate grid over the image in the
    # flow's dtype, then shift it by the negated flow: for every output
    # pixel this yields the source-image location to sample from.
    grid_x, grid_y = array_ops.meshgrid(
        math_ops.range(width), math_ops.range(height))
    coord_grid = math_ops.cast(
        array_ops.stack([grid_y, grid_x], axis=2), flow.dtype)
    sample_coords = array_ops.expand_dims(coord_grid, axis=0) - flow

    # The bilinear interpolator expects a flat list of query points per
    # batch element, so collapse the spatial dimensions before sampling
    # and restore them afterwards.
    flat_coords = array_ops.reshape(sample_coords,
                                    [batch_size, height * width, 2])
    warped = _interpolate_bilinear(image, flat_coords)
    return array_ops.reshape(warped, [batch_size, height, width, channels])
