1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Image warping using sparse flow defined at control points.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import numpy as np 21 22from tensorflow.contrib.image.python.ops import dense_image_warp 23from tensorflow.contrib.image.python.ops import interpolate_spline 24 25from tensorflow.python.framework import constant_op 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_shape 28from tensorflow.python.ops import array_ops 29 30 31def _get_grid_locations(image_height, image_width): 32 """Wrapper for np.meshgrid.""" 33 34 y_range = np.linspace(0, image_height - 1, image_height) 35 x_range = np.linspace(0, image_width - 1, image_width) 36 y_grid, x_grid = np.meshgrid(y_range, x_range, indexing='ij') 37 return np.stack((y_grid, x_grid), -1) 38 39 40def _expand_to_minibatch(np_array, batch_size): 41 """Tile arbitrarily-sized np_array to include new batch dimension.""" 42 tiles = [batch_size] + [1] * np_array.ndim 43 return np.tile(np.expand_dims(np_array, 0), tiles) 44 45 46def _get_boundary_locations(image_height, image_width, num_points_per_edge): 47 """Compute evenly-spaced indices along edge of image.""" 48 y_range = np.linspace(0, image_height - 1, num_points_per_edge + 2) 49 x_range = np.linspace(0, image_width - 1, num_points_per_edge + 2) 50 ys, xs = np.meshgrid(y_range, x_range, indexing='ij') 51 is_boundary = np.logical_or( 52 np.logical_or(xs == 0, xs == image_width - 1), 53 np.logical_or(ys == 0, ys == image_height - 1)) 54 return np.stack([ys[is_boundary], xs[is_boundary]], axis=-1) 55 56 57def _add_zero_flow_controls_at_boundary(control_point_locations, 58 control_point_flows, image_height, 59 image_width, boundary_points_per_edge): 60 """Add control points for zero-flow boundary conditions. 61 62 Augment the set of control points with extra points on the 63 boundary of the image that have zero flow. 64 65 Args: 66 control_point_locations: input control points 67 control_point_flows: their flows 68 image_height: image height 69 image_width: image width 70 boundary_points_per_edge: number of points to add in the middle of each 71 edge (not including the corners). 72 The total number of points added is 73 4 + 4*(boundary_points_per_edge). 74 75 Returns: 76 merged_control_point_locations: augmented set of control point locations 77 merged_control_point_flows: augmented set of control point flows 78 """ 79 80 batch_size = tensor_shape.dimension_value(control_point_locations.shape[0]) 81 82 boundary_point_locations = _get_boundary_locations(image_height, image_width, 83 boundary_points_per_edge) 84 85 boundary_point_flows = np.zeros([boundary_point_locations.shape[0], 2]) 86 87 type_to_use = control_point_locations.dtype 88 boundary_point_locations = constant_op.constant( 89 _expand_to_minibatch(boundary_point_locations, batch_size), 90 dtype=type_to_use) 91 92 boundary_point_flows = constant_op.constant( 93 _expand_to_minibatch(boundary_point_flows, batch_size), dtype=type_to_use) 94 95 merged_control_point_locations = array_ops.concat( 96 [control_point_locations, boundary_point_locations], 1) 97 98 merged_control_point_flows = array_ops.concat( 99 [control_point_flows, boundary_point_flows], 1) 100 101 return merged_control_point_locations, merged_control_point_flows 102 103 104def sparse_image_warp(image, 105 source_control_point_locations, 106 dest_control_point_locations, 107 interpolation_order=2, 108 regularization_weight=0.0, 109 num_boundary_points=0, 110 name='sparse_image_warp'): 111 """Image warping using correspondences between sparse control points. 112 113 Apply a non-linear warp to the image, where the warp is specified by 114 the source and destination locations of a (potentially small) number of 115 control points. First, we use a polyharmonic spline 116 (`tf.contrib.image.interpolate_spline`) to interpolate the displacements 117 between the corresponding control points to a dense flow field. 118 Then, we warp the image using this dense flow field 119 (`tf.contrib.image.dense_image_warp`). 120 121 Let t index our control points. For regularization_weight=0, we have: 122 warped_image[b, dest_control_point_locations[b, t, 0], 123 dest_control_point_locations[b, t, 1], :] = 124 image[b, source_control_point_locations[b, t, 0], 125 source_control_point_locations[b, t, 1], :]. 126 127 For regularization_weight > 0, this condition is met approximately, since 128 regularized interpolation trades off smoothness of the interpolant vs. 129 reconstruction of the interpolant at the control points. 130 See `tf.contrib.image.interpolate_spline` for further documentation of the 131 interpolation_order and regularization_weight arguments. 132 133 134 Args: 135 image: `[batch, height, width, channels]` float `Tensor` 136 source_control_point_locations: `[batch, num_control_points, 2]` float 137 `Tensor` 138 dest_control_point_locations: `[batch, num_control_points, 2]` float 139 `Tensor` 140 interpolation_order: polynomial order used by the spline interpolation 141 regularization_weight: weight on smoothness regularizer in interpolation 142 num_boundary_points: How many zero-flow boundary points to include at 143 each image edge.Usage: 144 num_boundary_points=0: don't add zero-flow points 145 num_boundary_points=1: 4 corners of the image 146 num_boundary_points=2: 4 corners and one in the middle of each edge 147 (8 points total) 148 num_boundary_points=n: 4 corners and n-1 along each edge 149 name: A name for the operation (optional). 150 151 Note that image and offsets can be of type tf.half, tf.float32, or 152 tf.float64, and do not necessarily have to be the same type. 153 154 Returns: 155 warped_image: `[batch, height, width, channels]` float `Tensor` with same 156 type as input image. 157 flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense 158 flow field produced by the interpolation. 159 """ 160 161 image = ops.convert_to_tensor(image) 162 source_control_point_locations = ops.convert_to_tensor( 163 source_control_point_locations) 164 dest_control_point_locations = ops.convert_to_tensor( 165 dest_control_point_locations) 166 167 control_point_flows = ( 168 dest_control_point_locations - source_control_point_locations) 169 170 clamp_boundaries = num_boundary_points > 0 171 boundary_points_per_edge = num_boundary_points - 1 172 173 with ops.name_scope(name): 174 175 batch_size, image_height, image_width, _ = image.get_shape().as_list() 176 177 # This generates the dense locations where the interpolant 178 # will be evaluated. 179 grid_locations = _get_grid_locations(image_height, image_width) 180 181 flattened_grid_locations = np.reshape(grid_locations, 182 [image_height * image_width, 2]) 183 184 flattened_grid_locations = constant_op.constant( 185 _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype) 186 187 if clamp_boundaries: 188 (dest_control_point_locations, 189 control_point_flows) = _add_zero_flow_controls_at_boundary( 190 dest_control_point_locations, control_point_flows, image_height, 191 image_width, boundary_points_per_edge) 192 193 flattened_flows = interpolate_spline.interpolate_spline( 194 dest_control_point_locations, control_point_flows, 195 flattened_grid_locations, interpolation_order, regularization_weight) 196 197 dense_flows = array_ops.reshape(flattened_flows, 198 [batch_size, image_height, image_width, 2]) 199 200 warped_image = dense_image_warp.dense_image_warp(image, dense_flows) 201 202 return warped_image, dense_flows 203