1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Image warping using per-pixel flow vectors.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.python.framework import constant_op 23from tensorflow.python.framework import dtypes 24from tensorflow.python.framework import ops 25 26from tensorflow.python.ops import array_ops 27from tensorflow.python.ops import check_ops 28from tensorflow.python.ops import math_ops 29 30 31def _interpolate_bilinear(grid, 32 query_points, 33 name='interpolate_bilinear', 34 indexing='ij'): 35 """Similar to Matlab's interp2 function. 36 37 Finds values for query points on a grid using bilinear interpolation. 38 39 Args: 40 grid: a 4-D float `Tensor` of shape `[batch, height, width, channels]`. 41 query_points: a 3-D float `Tensor` of N points with shape `[batch, N, 2]`. 42 name: a name for the operation (optional). 43 indexing: whether the query points are specified as row and column (ij), 44 or Cartesian coordinates (xy). 45 46 Returns: 47 values: a 3-D `Tensor` with shape `[batch, N, channels]` 48 49 Raises: 50 ValueError: if the indexing mode is invalid, or if the shape of the inputs 51 invalid. 52 """ 53 if indexing != 'ij' and indexing != 'xy': 54 raise ValueError('Indexing mode must be \'ij\' or \'xy\'') 55 56 with ops.name_scope(name): 57 grid = ops.convert_to_tensor(grid) 58 query_points = ops.convert_to_tensor(query_points) 59 shape = grid.get_shape().as_list() 60 if len(shape) != 4: 61 msg = 'Grid must be 4 dimensional. Received size: ' 62 raise ValueError(msg + str(grid.get_shape())) 63 64 batch_size, height, width, channels = (array_ops.shape(grid)[0], 65 array_ops.shape(grid)[1], 66 array_ops.shape(grid)[2], 67 array_ops.shape(grid)[3]) 68 69 shape = [batch_size, height, width, channels] 70 query_type = query_points.dtype 71 grid_type = grid.dtype 72 73 with ops.control_dependencies([ 74 check_ops.assert_equal( 75 len(query_points.get_shape()), 76 3, 77 message='Query points must be 3 dimensional.'), 78 check_ops.assert_equal( 79 array_ops.shape(query_points)[2], 80 2, 81 message='Query points must be size 2 in dim 2.') 82 ]): 83 num_queries = array_ops.shape(query_points)[1] 84 85 with ops.control_dependencies([ 86 check_ops.assert_greater_equal( 87 height, 2, message='Grid height must be at least 2.'), 88 check_ops.assert_greater_equal( 89 width, 2, message='Grid width must be at least 2.') 90 ]): 91 alphas = [] 92 floors = [] 93 ceils = [] 94 index_order = [0, 1] if indexing == 'ij' else [1, 0] 95 unstacked_query_points = array_ops.unstack(query_points, axis=2) 96 97 for dim in index_order: 98 with ops.name_scope('dim-' + str(dim)): 99 queries = unstacked_query_points[dim] 100 101 size_in_indexing_dimension = shape[dim + 1] 102 103 # max_floor is size_in_indexing_dimension - 2 so that max_floor + 1 104 # is still a valid index into the grid. 105 max_floor = math_ops.cast(size_in_indexing_dimension - 2, query_type) 106 min_floor = constant_op.constant(0.0, dtype=query_type) 107 floor = math_ops.minimum( 108 math_ops.maximum(min_floor, math_ops.floor(queries)), max_floor) 109 int_floor = math_ops.cast(floor, dtypes.int32) 110 floors.append(int_floor) 111 ceil = int_floor + 1 112 ceils.append(ceil) 113 114 # alpha has the same type as the grid, as we will directly use alpha 115 # when taking linear combinations of pixel values from the image. 116 alpha = math_ops.cast(queries - floor, grid_type) 117 min_alpha = constant_op.constant(0.0, dtype=grid_type) 118 max_alpha = constant_op.constant(1.0, dtype=grid_type) 119 alpha = math_ops.minimum(math_ops.maximum(min_alpha, alpha), max_alpha) 120 121 # Expand alpha to [b, n, 1] so we can use broadcasting 122 # (since the alpha values don't depend on the channel). 123 alpha = array_ops.expand_dims(alpha, 2) 124 alphas.append(alpha) 125 126 with ops.control_dependencies([ 127 check_ops.assert_less_equal( 128 math_ops.cast(batch_size * height * width, dtype=dtypes.float32), 129 np.iinfo(np.int32).max / 8, 130 message="""The image size or batch size is sufficiently large 131 that the linearized addresses used by array_ops.gather 132 may exceed the int32 limit.""") 133 ]): 134 flattened_grid = array_ops.reshape( 135 grid, [batch_size * height * width, channels]) 136 batch_offsets = array_ops.reshape( 137 math_ops.range(batch_size) * height * width, [batch_size, 1]) 138 139 # This wraps array_ops.gather. We reshape the image data such that the 140 # batch, y, and x coordinates are pulled into the first dimension. 141 # Then we gather. Finally, we reshape the output back. It's possible this 142 # code would be made simpler by using array_ops.gather_nd. 143 def gather(y_coords, x_coords, name): 144 with ops.name_scope('gather-' + name): 145 linear_coordinates = batch_offsets + y_coords * width + x_coords 146 gathered_values = array_ops.gather(flattened_grid, linear_coordinates) 147 return array_ops.reshape(gathered_values, 148 [batch_size, num_queries, channels]) 149 150 # grab the pixel values in the 4 corners around each query point 151 top_left = gather(floors[0], floors[1], 'top_left') 152 top_right = gather(floors[0], ceils[1], 'top_right') 153 bottom_left = gather(ceils[0], floors[1], 'bottom_left') 154 bottom_right = gather(ceils[0], ceils[1], 'bottom_right') 155 156 # now, do the actual interpolation 157 with ops.name_scope('interpolate'): 158 interp_top = alphas[1] * (top_right - top_left) + top_left 159 interp_bottom = alphas[1] * (bottom_right - bottom_left) + bottom_left 160 interp = alphas[0] * (interp_bottom - interp_top) + interp_top 161 162 return interp 163 164 165def dense_image_warp(image, flow, name='dense_image_warp'): 166 """Image warping using per-pixel flow vectors. 167 168 Apply a non-linear warp to the image, where the warp is specified by a dense 169 flow field of offset vectors that define the correspondences of pixel values 170 in the output image back to locations in the source image. Specifically, the 171 pixel value at output[b, j, i, c] is 172 images[b, j - flow[b, j, i, 0], i - flow[b, j, i, 1], c]. 173 174 The locations specified by this formula do not necessarily map to an int 175 index. Therefore, the pixel value is obtained by bilinear 176 interpolation of the 4 nearest pixels around 177 (b, j - flow[b, j, i, 0], i - flow[b, j, i, 1]). For locations outside 178 of the image, we use the nearest pixel values at the image boundary. 179 180 181 Args: 182 image: 4-D float `Tensor` with shape `[batch, height, width, channels]`. 183 flow: A 4-D float `Tensor` with shape `[batch, height, width, 2]`. 184 name: A name for the operation (optional). 185 186 Note that image and flow can be of type tf.half, tf.float32, or tf.float64, 187 and do not necessarily have to be the same type. 188 189 Returns: 190 A 4-D float `Tensor` with shape`[batch, height, width, channels]` 191 and same type as input image. 192 193 Raises: 194 ValueError: if height < 2 or width < 2 or the inputs have the wrong number 195 of dimensions. 196 """ 197 with ops.name_scope(name): 198 batch_size, height, width, channels = (array_ops.shape(image)[0], 199 array_ops.shape(image)[1], 200 array_ops.shape(image)[2], 201 array_ops.shape(image)[3]) 202 203 # The flow is defined on the image grid. Turn the flow into a list of query 204 # points in the grid space. 205 grid_x, grid_y = array_ops.meshgrid( 206 math_ops.range(width), math_ops.range(height)) 207 stacked_grid = math_ops.cast( 208 array_ops.stack([grid_y, grid_x], axis=2), flow.dtype) 209 batched_grid = array_ops.expand_dims(stacked_grid, axis=0) 210 query_points_on_grid = batched_grid - flow 211 query_points_flattened = array_ops.reshape(query_points_on_grid, 212 [batch_size, height * width, 2]) 213 # Compute values at the query points, then reshape the result back to the 214 # image grid. 215 interpolated = _interpolate_bilinear(image, query_points_flattened) 216 interpolated = array_ops.reshape(interpolated, 217 [batch_size, height, width, channels]) 218 return interpolated 219