1# Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Python layer for image_ops.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20from tensorflow.python.eager import context 21from tensorflow.contrib.image.ops import gen_image_ops 22from tensorflow.contrib.util import loader 23from tensorflow.python.framework import common_shapes 24from tensorflow.python.framework import constant_op 25from tensorflow.python.framework import dtypes 26from tensorflow.python.framework import ops 27from tensorflow.python.framework import tensor_util 28from tensorflow.python.ops import array_ops 29from tensorflow.python.ops import control_flow_ops 30from tensorflow.python.ops import linalg_ops 31from tensorflow.python.ops import math_ops 32from tensorflow.python.platform import resource_loader 33 34_image_ops_so = loader.load_op_library( 35 resource_loader.get_path_to_datafile("_image_ops.so")) 36 37_IMAGE_DTYPES = set( 38 [dtypes.uint8, dtypes.int32, dtypes.int64, 39 dtypes.float16, dtypes.float32, dtypes.float64]) 40 41ops.RegisterShape("ImageConnectedComponents")(common_shapes.call_cpp_shape_fn) 42ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn) 43ops.RegisterShape("ImageProjectiveTransformV2")(common_shapes.call_cpp_shape_fn) 44 45 46# TODO(ringwalt): Support a "reshape" (name used by SciPy) or "expand" (name 47# used by PIL, maybe more readable) mode, which determines the correct 48# output_shape and translation for the transform. 49def rotate(images, angles, interpolation="NEAREST", name=None): 50 """Rotate image(s) counterclockwise by the passed angle(s) in radians. 51 52 Args: 53 images: A tensor of shape (num_images, num_rows, num_columns, num_channels) 54 (NHWC), (num_rows, num_columns, num_channels) (HWC), or 55 (num_rows, num_columns) (HW). The rank must be statically known (the 56 shape is not `TensorShape(None)`. 57 angles: A scalar angle to rotate all images by, or (if images has rank 4) 58 a vector of length num_images, with an angle for each image in the batch. 59 interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". 60 name: The name of the op. 61 62 Returns: 63 Image(s) with the same type and shape as `images`, rotated by the given 64 angle(s). Empty space due to the rotation will be filled with zeros. 65 66 Raises: 67 TypeError: If `image` is an invalid type. 68 """ 69 with ops.name_scope(name, "rotate"): 70 image_or_images = ops.convert_to_tensor(images) 71 if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: 72 raise TypeError("Invalid dtype %s." % image_or_images.dtype) 73 elif image_or_images.get_shape().ndims is None: 74 raise TypeError("image_or_images rank must be statically known") 75 elif len(image_or_images.get_shape()) == 2: 76 images = image_or_images[None, :, :, None] 77 elif len(image_or_images.get_shape()) == 3: 78 images = image_or_images[None, :, :, :] 79 elif len(image_or_images.get_shape()) == 4: 80 images = image_or_images 81 else: 82 raise TypeError("Images should have rank between 2 and 4.") 83 84 image_height = math_ops.cast(array_ops.shape(images)[1], 85 dtypes.float32)[None] 86 image_width = math_ops.cast(array_ops.shape(images)[2], 87 dtypes.float32)[None] 88 output = transform( 89 images, 90 angles_to_projective_transforms(angles, image_height, image_width), 91 interpolation=interpolation) 92 if image_or_images.get_shape().ndims is None: 93 raise TypeError("image_or_images rank must be statically known") 94 elif len(image_or_images.get_shape()) == 2: 95 return output[0, :, :, 0] 96 elif len(image_or_images.get_shape()) == 3: 97 return output[0, :, :, :] 98 else: 99 return output 100 101 102def translate(images, translations, interpolation="NEAREST", name=None): 103 """Translate image(s) by the passed vectors(s). 104 105 Args: 106 images: A tensor of shape (num_images, num_rows, num_columns, num_channels) 107 (NHWC), (num_rows, num_columns, num_channels) (HWC), or 108 (num_rows, num_columns) (HW). The rank must be statically known (the 109 shape is not `TensorShape(None)`. 110 translations: A vector representing [dx, dy] or (if images has rank 4) 111 a matrix of length num_images, with a [dx, dy] vector for each image in 112 the batch. 113 interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". 114 name: The name of the op. 115 116 Returns: 117 Image(s) with the same type and shape as `images`, translated by the given 118 vector(s). Empty space due to the translation will be filled with zeros. 119 120 Raises: 121 TypeError: If `image` is an invalid type. 122 """ 123 with ops.name_scope(name, "translate"): 124 return transform( 125 images, 126 translations_to_projective_transforms(translations), 127 interpolation=interpolation) 128 129 130def angles_to_projective_transforms(angles, 131 image_height, 132 image_width, 133 name=None): 134 """Returns projective transform(s) for the given angle(s). 135 136 Args: 137 angles: A scalar angle to rotate all images by, or (for batches of images) 138 a vector with an angle to rotate each image in the batch. The rank must 139 be statically known (the shape is not `TensorShape(None)`. 140 image_height: Height of the image(s) to be transformed. 141 image_width: Width of the image(s) to be transformed. 142 143 Returns: 144 A tensor of shape (num_images, 8). Projective transforms which can be given 145 to `tf.contrib.image.transform`. 146 """ 147 with ops.name_scope(name, "angles_to_projective_transforms"): 148 angle_or_angles = ops.convert_to_tensor( 149 angles, name="angles", dtype=dtypes.float32) 150 if len(angle_or_angles.get_shape()) == 0: # pylint: disable=g-explicit-length-test 151 angles = angle_or_angles[None] 152 elif len(angle_or_angles.get_shape()) == 1: 153 angles = angle_or_angles 154 else: 155 raise TypeError("Angles should have rank 0 or 1.") 156 x_offset = ((image_width - 1) - (math_ops.cos(angles) * 157 (image_width - 1) - math_ops.sin(angles) * 158 (image_height - 1))) / 2.0 159 y_offset = ((image_height - 1) - (math_ops.sin(angles) * 160 (image_width - 1) + math_ops.cos(angles) * 161 (image_height - 1))) / 2.0 162 num_angles = array_ops.shape(angles)[0] 163 return array_ops.concat( 164 values=[ 165 math_ops.cos(angles)[:, None], 166 -math_ops.sin(angles)[:, None], 167 x_offset[:, None], 168 math_ops.sin(angles)[:, None], 169 math_ops.cos(angles)[:, None], 170 y_offset[:, None], 171 array_ops.zeros((num_angles, 2), dtypes.float32), 172 ], 173 axis=1) 174 175 176def translations_to_projective_transforms(translations, name=None): 177 """Returns projective transform(s) for the given translation(s). 178 179 Args: 180 translations: A 2-element list representing [dx, dy] or a matrix of 181 2-element lists representing [dx, dy] to translate for each image 182 (for a batch of images). The rank must be statically known (the shape 183 is not `TensorShape(None)`. 184 name: The name of the op. 185 186 Returns: 187 A tensor of shape (num_images, 8) projective transforms which can be given 188 to `tf.contrib.image.transform`. 189 """ 190 with ops.name_scope(name, "translations_to_projective_transforms"): 191 translation_or_translations = ops.convert_to_tensor( 192 translations, name="translations", dtype=dtypes.float32) 193 if translation_or_translations.get_shape().ndims is None: 194 raise TypeError( 195 "translation_or_translations rank must be statically known") 196 elif len(translation_or_translations.get_shape()) == 1: 197 translations = translation_or_translations[None] 198 elif len(translation_or_translations.get_shape()) == 2: 199 translations = translation_or_translations 200 else: 201 raise TypeError("Translations should have rank 1 or 2.") 202 num_translations = array_ops.shape(translations)[0] 203 # The translation matrix looks like: 204 # [[1 0 -dx] 205 # [0 1 -dy] 206 # [0 0 1]] 207 # where the last entry is implicit. 208 # Translation matrices are always float32. 209 return array_ops.concat( 210 values=[ 211 array_ops.ones((num_translations, 1), dtypes.float32), 212 array_ops.zeros((num_translations, 1), dtypes.float32), 213 -translations[:, 0, None], 214 array_ops.zeros((num_translations, 1), dtypes.float32), 215 array_ops.ones((num_translations, 1), dtypes.float32), 216 -translations[:, 1, None], 217 array_ops.zeros((num_translations, 2), dtypes.float32), 218 ], 219 axis=1) 220 221 222def transform(images, 223 transforms, 224 interpolation="NEAREST", 225 output_shape=None, 226 name=None): 227 """Applies the given transform(s) to the image(s). 228 229 Args: 230 images: A tensor of shape (num_images, num_rows, num_columns, num_channels) 231 (NHWC), (num_rows, num_columns, num_channels) (HWC), or 232 (num_rows, num_columns) (HW). The rank must be statically known (the 233 shape is not `TensorShape(None)`. 234 transforms: Projective transform matrix/matrices. A vector of length 8 or 235 tensor of size N x 8. If one row of transforms is 236 [a0, a1, a2, b0, b1, b2, c0, c1], then it maps the *output* point 237 `(x, y)` to a transformed *input* point 238 `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`, 239 where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to 240 the transform mapping input points to output points. Note that gradients 241 are not backpropagated into transformation parameters. 242 interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR". 243 output_shape: Output dimesion after the transform, [height, width]. 244 If None, output is the same size as input image. 245 246 name: The name of the op. 247 248 Returns: 249 Image(s) with the same type and shape as `images`, with the given 250 transform(s) applied. Transformed coordinates outside of the input image 251 will be filled with zeros. 252 253 Raises: 254 TypeError: If `image` is an invalid type. 255 ValueError: If output shape is not 1-D int32 Tensor. 256 """ 257 with ops.name_scope(name, "transform"): 258 image_or_images = ops.convert_to_tensor(images, name="images") 259 transform_or_transforms = ops.convert_to_tensor( 260 transforms, name="transforms", dtype=dtypes.float32) 261 if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: 262 raise TypeError("Invalid dtype %s." % image_or_images.dtype) 263 elif image_or_images.get_shape().ndims is None: 264 raise TypeError("image_or_images rank must be statically known") 265 elif len(image_or_images.get_shape()) == 2: 266 images = image_or_images[None, :, :, None] 267 elif len(image_or_images.get_shape()) == 3: 268 images = image_or_images[None, :, :, :] 269 elif len(image_or_images.get_shape()) == 4: 270 images = image_or_images 271 else: 272 raise TypeError("Images should have rank between 2 and 4.") 273 274 if output_shape is None: 275 output_shape = array_ops.shape(images)[1:3] 276 if not context.executing_eagerly(): 277 output_shape_value = tensor_util.constant_value(output_shape) 278 if output_shape_value is not None: 279 output_shape = output_shape_value 280 281 output_shape = ops.convert_to_tensor( 282 output_shape, dtypes.int32, name="output_shape") 283 284 if not output_shape.get_shape().is_compatible_with([2]): 285 raise ValueError("output_shape must be a 1-D Tensor of 2 elements: " 286 "new_height, new_width") 287 288 if len(transform_or_transforms.get_shape()) == 1: 289 transforms = transform_or_transforms[None] 290 elif transform_or_transforms.get_shape().ndims is None: 291 raise TypeError( 292 "transform_or_transforms rank must be statically known") 293 elif len(transform_or_transforms.get_shape()) == 2: 294 transforms = transform_or_transforms 295 else: 296 raise TypeError("Transforms should have rank 1 or 2.") 297 298 output = gen_image_ops.image_projective_transform_v2( 299 images, 300 output_shape=output_shape, 301 transforms=transforms, 302 interpolation=interpolation.upper()) 303 if len(image_or_images.get_shape()) == 2: 304 return output[0, :, :, 0] 305 elif len(image_or_images.get_shape()) == 3: 306 return output[0, :, :, :] 307 else: 308 return output 309 310 311def compose_transforms(*transforms): 312 """Composes the transforms tensors. 313 314 Args: 315 *transforms: List of image projective transforms to be composed. Each 316 transform is length 8 (single transform) or shape (N, 8) (batched 317 transforms). The shapes of all inputs must be equal, and at least one 318 input must be given. 319 320 Returns: 321 A composed transform tensor. When passed to `tf.contrib.image.transform`, 322 equivalent to applying each of the given transforms to the image in 323 order. 324 """ 325 assert transforms, "transforms cannot be empty" 326 with ops.name_scope("compose_transforms"): 327 composed = flat_transforms_to_matrices(transforms[0]) 328 for tr in transforms[1:]: 329 # Multiply batches of matrices. 330 composed = math_ops.matmul(composed, flat_transforms_to_matrices(tr)) 331 return matrices_to_flat_transforms(composed) 332 333 334def flat_transforms_to_matrices(transforms): 335 """Converts `tf.contrib.image` projective transforms to affine matrices. 336 337 Note that the output matrices map output coordinates to input coordinates. For 338 the forward transformation matrix, call `tf.linalg.inv` on the result. 339 340 Args: 341 transforms: Vector of length 8, or batches of transforms with shape 342 `(N, 8)`. 343 344 Returns: 345 3D tensor of matrices with shape `(N, 3, 3)`. The output matrices map the 346 *output coordinates* (in homogeneous coordinates) of each transform to the 347 corresponding *input coordinates*. 348 349 Raises: 350 ValueError: If `transforms` have an invalid shape. 351 """ 352 with ops.name_scope("flat_transforms_to_matrices"): 353 transforms = ops.convert_to_tensor(transforms, name="transforms") 354 if transforms.shape.ndims not in (1, 2): 355 raise ValueError("Transforms should be 1D or 2D, got: %s" % transforms) 356 # Make the transform(s) 2D in case the input is a single transform. 357 transforms = array_ops.reshape(transforms, constant_op.constant([-1, 8])) 358 num_transforms = array_ops.shape(transforms)[0] 359 # Add a column of ones for the implicit last entry in the matrix. 360 return array_ops.reshape( 361 array_ops.concat( 362 [transforms, array_ops.ones([num_transforms, 1])], axis=1), 363 constant_op.constant([-1, 3, 3])) 364 365 366def matrices_to_flat_transforms(transform_matrices): 367 """Converts affine matrices to `tf.contrib.image` projective transforms. 368 369 Note that we expect matrices that map output coordinates to input coordinates. 370 To convert forward transformation matrices, call `tf.linalg.inv` on the 371 matrices and use the result here. 372 373 Args: 374 transform_matrices: One or more affine transformation matrices, for the 375 reverse transformation in homogeneous coordinates. Shape `(3, 3)` or 376 `(N, 3, 3)`. 377 378 Returns: 379 2D tensor of flat transforms with shape `(N, 8)`, which may be passed into 380 `tf.contrib.image.transform`. 381 382 Raises: 383 ValueError: If `transform_matrices` have an invalid shape. 384 """ 385 with ops.name_scope("matrices_to_flat_transforms"): 386 transform_matrices = ops.convert_to_tensor( 387 transform_matrices, name="transform_matrices") 388 if transform_matrices.shape.ndims not in (2, 3): 389 raise ValueError( 390 "Matrices should be 2D or 3D, got: %s" % transform_matrices) 391 # Flatten each matrix. 392 transforms = array_ops.reshape(transform_matrices, 393 constant_op.constant([-1, 9])) 394 # Divide each matrix by the last entry (normally 1). 395 transforms /= transforms[:, 8:9] 396 return transforms[:, :8] 397 398 399@ops.RegisterGradient("ImageProjectiveTransformV2") 400def _image_projective_transform_grad(op, grad): 401 """Computes the gradient for ImageProjectiveTransform.""" 402 images = op.inputs[0] 403 transforms = op.inputs[1] 404 interpolation = op.get_attr("interpolation") 405 406 image_or_images = ops.convert_to_tensor(images, name="images") 407 transform_or_transforms = ops.convert_to_tensor( 408 transforms, name="transforms", dtype=dtypes.float32) 409 410 if image_or_images.dtype.base_dtype not in _IMAGE_DTYPES: 411 raise TypeError("Invalid dtype %s." % image_or_images.dtype) 412 if len(transform_or_transforms.get_shape()) == 1: 413 transforms = transform_or_transforms[None] 414 elif len(transform_or_transforms.get_shape()) == 2: 415 transforms = transform_or_transforms 416 else: 417 raise TypeError("Transforms should have rank 1 or 2.") 418 419 # Invert transformations 420 transforms = flat_transforms_to_matrices(transforms=transforms) 421 inverse = linalg_ops.matrix_inverse(transforms) 422 transforms = matrices_to_flat_transforms(inverse) 423 output = gen_image_ops.image_projective_transform_v2( 424 images=grad, 425 transforms=transforms, 426 output_shape=array_ops.shape(image_or_images)[1:3], 427 interpolation=interpolation) 428 return [output, None, None] 429 430 431def bipartite_match(distance_mat, 432 num_valid_rows, 433 top_k=-1, 434 name="bipartite_match"): 435 """Find bipartite matching based on a given distance matrix. 436 437 A greedy bi-partite matching algorithm is used to obtain the matching with 438 the (greedy) minimum distance. 439 440 Args: 441 distance_mat: A 2-D float tensor of shape `[num_rows, num_columns]`. It is a 442 pair-wise distance matrix between the entities represented by each row and 443 each column. It is an asymmetric matrix. The smaller the distance is, the 444 more similar the pairs are. The bipartite matching is to minimize the 445 distances. 446 num_valid_rows: A scalar or a 1-D tensor with one element describing the 447 number of valid rows of distance_mat to consider for the bipartite 448 matching. If set to be negative, then all rows from `distance_mat` are 449 used. 450 top_k: A scalar that specifies the number of top-k matches to retrieve. 451 If set to be negative, then is set according to the maximum number of 452 matches from `distance_mat`. 453 name: The name of the op. 454 455 Returns: 456 row_to_col_match_indices: A vector of length num_rows, which is the number 457 of rows of the input `distance_matrix`. If `row_to_col_match_indices[i]` 458 is not -1, row i is matched to column `row_to_col_match_indices[i]`. 459 col_to_row_match_indices: A vector of length num_columns, which is the 460 number of columns of the input distance matrix. 461 If `col_to_row_match_indices[j]` is not -1, column j is matched to row 462 `col_to_row_match_indices[j]`. 463 """ 464 result = gen_image_ops.bipartite_match( 465 distance_mat, num_valid_rows, top_k, name=name) 466 return result 467 468 469def connected_components(images): 470 """Labels the connected components in a batch of images. 471 472 A component is a set of pixels in a single input image, which are all adjacent 473 and all have the same non-zero value. The components using a squared 474 connectivity of one (all True entries are joined with their neighbors above, 475 below, left, and right). Components across all images have consecutive ids 1 476 through n. Components are labeled according to the first pixel of the 477 component appearing in row-major order (lexicographic order by 478 image_index_in_batch, row, col). Zero entries all have an output id of 0. 479 480 This op is equivalent with `scipy.ndimage.measurements.label` on a 2D array 481 with the default structuring element (which is the connectivity used here). 482 483 Args: 484 images: A 2D (H, W) or 3D (N, H, W) Tensor of boolean image(s). 485 486 Returns: 487 Components with the same shape as `images`. False entries in `images` have 488 value 0, and all True entries map to a component id > 0. 489 490 Raises: 491 TypeError: if `images` is not 2D or 3D. 492 """ 493 with ops.name_scope("connected_components"): 494 image_or_images = ops.convert_to_tensor(images, name="images") 495 if len(image_or_images.get_shape()) == 2: 496 images = image_or_images[None, :, :] 497 elif len(image_or_images.get_shape()) == 3: 498 images = image_or_images 499 else: 500 raise TypeError( 501 "images should have rank 2 (HW) or 3 (NHW). Static shape is %s" % 502 image_or_images.get_shape()) 503 components = gen_image_ops.image_connected_components(images) 504 505 # TODO(ringwalt): Component id renaming should be done in the op, to avoid 506 # constructing multiple additional large tensors. 507 components_flat = array_ops.reshape(components, [-1]) 508 unique_ids, id_index = array_ops.unique(components_flat) 509 id_is_zero = array_ops.where(math_ops.equal(unique_ids, 0))[:, 0] 510 # Map each nonzero id to consecutive values. 511 nonzero_consecutive_ids = math_ops.range( 512 array_ops.shape(unique_ids)[0] - array_ops.shape(id_is_zero)[0]) + 1 513 514 def no_zero(): 515 # No need to insert a zero into the ids. 516 return nonzero_consecutive_ids 517 518 def has_zero(): 519 # Insert a zero in the consecutive ids where zero appears in unique_ids. 520 # id_is_zero has length 1. 521 zero_id_ind = math_ops.cast(id_is_zero[0], dtypes.int32) 522 ids_before = nonzero_consecutive_ids[:zero_id_ind] 523 ids_after = nonzero_consecutive_ids[zero_id_ind:] 524 return array_ops.concat([ids_before, [0], ids_after], axis=0) 525 526 new_ids = control_flow_ops.cond( 527 math_ops.equal(array_ops.shape(id_is_zero)[0], 0), no_zero, has_zero) 528 components = array_ops.reshape( 529 array_ops.gather(new_ids, id_index), array_ops.shape(components)) 530 if len(image_or_images.get_shape()) == 2: 531 return components[0, :, :] 532 else: 533 return components 534 535 536ops.NotDifferentiable("BipartiteMatch") 537ops.NotDifferentiable("ImageConnectedComponents") 538