• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Image warping using sparse flow defined at control points."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import numpy as np
21
22from tensorflow.contrib.image.python.ops import dense_image_warp
23from tensorflow.contrib.image.python.ops import interpolate_spline
24
25from tensorflow.python.framework import constant_op
26from tensorflow.python.framework import ops
27from tensorflow.python.framework import tensor_shape
28from tensorflow.python.ops import array_ops
29
30
def _get_grid_locations(image_height, image_width):
  """Build the dense grid of (y, x) pixel coordinates for an image.

  Args:
    image_height: number of rows in the image.
    image_width: number of columns in the image.

  Returns:
    A numpy array of shape `[image_height, image_width, 2]` whose entry at
    `[i, j]` is the coordinate pair `(i, j)`.
  """
  row_coords = np.linspace(0, image_height - 1, image_height)
  col_coords = np.linspace(0, image_width - 1, image_width)
  # 'ij' indexing keeps the first output axis aligned with rows (y).
  coord_grids = np.meshgrid(row_coords, col_coords, indexing='ij')
  return np.stack(coord_grids, -1)
38
39
def _expand_to_minibatch(np_array, batch_size):
  """Replicate np_array along a new leading batch axis.

  Args:
    np_array: numpy array of any rank.
    batch_size: number of copies to place along the new first axis.

  Returns:
    A numpy array of shape `[batch_size] + list(np_array.shape)` in which
    every slice along axis 0 equals `np_array`.
  """
  # Repeat only along the new batch axis; leave every original axis as is.
  repetitions = [batch_size]
  repetitions.extend(1 for _ in range(np_array.ndim))
  with_batch_axis = np_array[np.newaxis, ...]
  return np.tile(with_batch_axis, repetitions)
44
45
def _get_boundary_locations(image_height, image_width, num_points_per_edge):
  """Return (y, x) coordinates of evenly spaced points on the image border.

  A `(num_points_per_edge + 2)`-by-`(num_points_per_edge + 2)` lattice is laid
  over the image and only the lattice points lying on the first/last row or
  first/last column are kept.

  Args:
    image_height: number of rows in the image.
    image_width: number of columns in the image.
    num_points_per_edge: number of points on each edge, excluding the corners.

  Returns:
    A numpy array of shape `[num_boundary_points, 2]` holding `(y, x)` pairs.
  """
  samples_per_axis = num_points_per_edge + 2
  last_row = image_height - 1
  last_col = image_width - 1

  rows, cols = np.meshgrid(
      np.linspace(0, last_row, samples_per_axis),
      np.linspace(0, last_col, samples_per_axis),
      indexing='ij')

  on_left_or_right = (cols == 0) | (cols == last_col)
  on_top_or_bottom = (rows == 0) | (rows == last_row)
  on_border = on_left_or_right | on_top_or_bottom

  return np.stack([rows[on_border], cols[on_border]], axis=-1)
55
56
def _add_zero_flow_controls_at_boundary(control_point_locations,
                                        control_point_flows, image_height,
                                        image_width, boundary_points_per_edge):
  """Append zero-flow control points located on the image boundary.

  Extra control points are placed on the border of the image and given a flow
  of zero, then concatenated after the caller's control points.

  Args:
    control_point_locations: input control points
    control_point_flows: their flows
    image_height: image height
    image_width: image width
    boundary_points_per_edge: number of points to add in the middle of each
                           edge (not including the corners).
                           The total number of points added is
                           4 + 4*(boundary_points_per_edge).

  Returns:
    merged_control_point_locations: augmented set of control point locations
    merged_control_point_flows: augmented set of control point flows
  """
  num_batches = tensor_shape.dimension_value(control_point_locations.shape[0])
  # The added constants take the dtype of the control point locations.
  target_dtype = control_point_locations.dtype

  def _batched_constant(values):
    """Tile `values` across the batch and wrap it as a constant tensor."""
    return constant_op.constant(
        _expand_to_minibatch(values, num_batches), dtype=target_dtype)

  edge_locations = _get_boundary_locations(image_height, image_width,
                                           boundary_points_per_edge)
  # One (dy, dx) = (0, 0) flow vector per boundary point.
  edge_flows = np.zeros([edge_locations.shape[0], 2])

  edge_locations = _batched_constant(edge_locations)
  edge_flows = _batched_constant(edge_flows)

  # Axis 1 is the control-point axis; boundary points go after the inputs.
  merged_control_point_locations = array_ops.concat(
      [control_point_locations, edge_locations], 1)
  merged_control_point_flows = array_ops.concat(
      [control_point_flows, edge_flows], 1)

  return merged_control_point_locations, merged_control_point_flows
102
103
def sparse_image_warp(image,
                      source_control_point_locations,
                      dest_control_point_locations,
                      interpolation_order=2,
                      regularization_weight=0.0,
                      num_boundary_points=0,
                      name='sparse_image_warp'):
  """Image warping using correspondences between sparse control points.

  Apply a non-linear warp to the image, where the warp is specified by
  the source and destination locations of a (potentially small) number of
  control points. First, we use a polyharmonic spline
  (`tf.contrib.image.interpolate_spline`) to interpolate the displacements
  between the corresponding control points to a dense flow field.
  Then, we warp the image using this dense flow field
  (`tf.contrib.image.dense_image_warp`).

  Let t index our control points. For regularization_weight=0, we have:
  warped_image[b, dest_control_point_locations[b, t, 0],
                  dest_control_point_locations[b, t, 1], :] =
  image[b, source_control_point_locations[b, t, 0],
           source_control_point_locations[b, t, 1], :].

  For regularization_weight > 0, this condition is met approximately, since
  regularized interpolation trades off smoothness of the interpolant vs.
  reconstruction of the interpolant at the control points.
  See `tf.contrib.image.interpolate_spline` for further documentation of the
  interpolation_order and regularization_weight arguments.


  Args:
    image: `[batch, height, width, channels]` float `Tensor`. The batch,
      height and width dimensions must be statically known.
    source_control_point_locations: `[batch, num_control_points, 2]` float
      `Tensor`
    dest_control_point_locations: `[batch, num_control_points, 2]` float
      `Tensor`
    interpolation_order: polynomial order used by the spline interpolation
    regularization_weight: weight on smoothness regularizer in interpolation
    num_boundary_points: How many zero-flow boundary points to include at
      each image edge. Usage:
        num_boundary_points=0: don't add zero-flow points
        num_boundary_points=1: 4 corners of the image
        num_boundary_points=2: 4 corners and one in the middle of each edge
          (8 points total)
        num_boundary_points=n: 4 corners and n-1 along each edge
    name: A name for the operation (optional).

    Note that image and offsets can be of type tf.half, tf.float32, or
    tf.float64, and do not necessarily have to be the same type.

  Returns:
    warped_image: `[batch, height, width, channels]` float `Tensor` with same
      type as input image.
    flow_field: `[batch, height, width, 2]` float `Tensor` containing the dense
      flow field produced by the interpolation.

  Raises:
    ValueError: if the batch, height or width dimension of `image` is not
      statically known.
  """

  image = ops.convert_to_tensor(image)
  source_control_point_locations = ops.convert_to_tensor(
      source_control_point_locations)
  dest_control_point_locations = ops.convert_to_tensor(
      dest_control_point_locations)

  # Displacement of every control point from its source to its destination.
  control_point_flows = (
      dest_control_point_locations - source_control_point_locations)

  clamp_boundaries = num_boundary_points > 0
  # The four corners are always added when clamping, so one fewer point is
  # requested for the interior of each edge.
  boundary_points_per_edge = num_boundary_points - 1

  with ops.name_scope(name):

    batch_size, image_height, image_width, _ = image.get_shape().as_list()

    # The evaluation grid is built with numpy at graph-construction time, so
    # these dimensions must be concrete integers. Fail early with a clear
    # message instead of an opaque TypeError from the numpy helpers.
    if batch_size is None or image_height is None or image_width is None:
      raise ValueError(
          'sparse_image_warp requires the batch, height and width dimensions '
          'of `image` to be statically known; got shape %s.' %
          image.get_shape())

    # This generates the dense locations where the interpolant
    # will be evaluated.
    grid_locations = _get_grid_locations(image_height, image_width)

    flattened_grid_locations = np.reshape(grid_locations,
                                          [image_height * image_width, 2])

    flattened_grid_locations = constant_op.constant(
        _expand_to_minibatch(flattened_grid_locations, batch_size), image.dtype)

    if clamp_boundaries:
      (dest_control_point_locations,
       control_point_flows) = _add_zero_flow_controls_at_boundary(
           dest_control_point_locations, control_point_flows, image_height,
           image_width, boundary_points_per_edge)

    # Interpolate the sparse flows, anchored at the destination locations,
    # onto every pixel of the grid.
    flattened_flows = interpolate_spline.interpolate_spline(
        dest_control_point_locations, control_point_flows,
        flattened_grid_locations, interpolation_order, regularization_weight)

    dense_flows = array_ops.reshape(flattened_flows,
                                    [batch_size, image_height, image_width, 2])

    warped_image = dense_image_warp.dense_image_warp(image, dense_flows)

    return warped_image, dense_flows
203