• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Image operations for RaggedTensors."""
16
17from tensorflow.python.framework import dtypes
18from tensorflow.python.framework import ops
19from tensorflow.python.framework import tensor_shape
20from tensorflow.python.framework import tensor_spec
21from tensorflow.python.framework import tensor_util
22from tensorflow.python.ops import array_ops
23from tensorflow.python.ops import control_flow_ops
24from tensorflow.python.ops import image_ops
25from tensorflow.python.ops import map_fn
26from tensorflow.python.ops import math_ops
27from tensorflow.python.ops.ragged import ragged_tensor
28from tensorflow.python.util import dispatch
29
30
@dispatch.dispatch_for_api(image_ops.resize_images_v2)
def resize_images_v2(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethod.BILINEAR,
                     preserve_aspect_ratio=False,
                     antialias=False,
                     name=None):
  """Dispatches the TF2 `tf.image.resize` API for `RaggedTensor` images.

  Args:
    images: A `RaggedTensor` batch of images (must have rank 4).
    size: The `[new_height, new_width]` target size.
    method: Resize method, forwarded unchanged to
      `image_ops.resize_images_v2`.
    preserve_aspect_ratio: Forwarded to `image_ops.resize_images_v2`.
    antialias: Forwarded to `image_ops.resize_images_v2`.
    name: Optional name for the op scope; "RaggedResizeImages" is used when
      it is None.

  Returns:
    Whatever `_resize_images` produces for these arguments.
  """
  forwarded = dict(
      method=method,
      preserve_aspect_ratio=preserve_aspect_ratio,
      antialias=antialias)
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    resized = _resize_images(
        image_ops.resize_images_v2, images, size, **forwarded)
  return resized
47
48
@dispatch.dispatch_for_api(image_ops.resize_images)
def resize_images_v1(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethodV1.BILINEAR,
                     align_corners=False,
                     preserve_aspect_ratio=False,
                     name=None):
  """Dispatches the TF1 `tf.image.resize` API for `RaggedTensor` images.

  Args:
    images: A `RaggedTensor` batch of images (must have rank 4).
    size: The `[new_height, new_width]` target size.
    method: Resize method, forwarded unchanged to `image_ops.resize_images`.
    align_corners: Forwarded to `image_ops.resize_images`.
    preserve_aspect_ratio: Forwarded to `image_ops.resize_images`.
    name: Optional name for the op scope; "RaggedResizeImages" is used when
      it is None.

  Returns:
    Whatever `_resize_images` produces for these arguments.
  """
  forwarded = dict(
      method=method,
      preserve_aspect_ratio=preserve_aspect_ratio,
      align_corners=align_corners)
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    resized = _resize_images(
        image_ops.resize_images, images, size, **forwarded)
  return resized
65
66
def _resize_images(resize_op, images, size, **kwargs):
  """Resizes every image of a ragged batch with a dense resize op.

  Each batch element of `images` is densified (if still ragged) and passed
  to `resize_op(image, size, **kwargs)` through `map_fn`. The per-image
  results are stacked into one dense tensor.

  NOTE(review): with `preserve_aspect_ratio=True` the dense op may produce
  an output size that differs from `size` and varies per image. The fixed
  `map_fn` output signature built here presumably cannot represent that.
  TODO confirm and consider rejecting that combination explicitly.

  Args:
    resize_op: Dense resize function applied to each image, e.g.
      `image_ops.resize_images_v2` or `image_ops.resize_images`.
    images: A `RaggedTensor`; must have rank 4.
    size: Value convertible to an int32 tensor with exactly 2 elements,
      `[new_height, new_width]`.
    **kwargs: Keyword arguments forwarded to `resize_op` (e.g. `method`,
      `preserve_aspect_ratio`, `antialias`, `align_corners`).

  Returns:
    A dense `Tensor` of shape `[batch, new_height, new_width, channels]`.
    Its dtype is `images.dtype` for nearest-neighbor resizing and `float32`
    for every other method, mirroring the dense resize ops.

  Raises:
    ValueError: If `images` does not have rank 4, or if `size` is not known
      to have exactly 2 elements.
  """
  if images.shape.rank != 4:
    raise ValueError(
        "tf.image.resize: images.shape.rank must be 4 if images is ragged.")

  # The dense resize ops keep the input dtype for nearest-neighbor resizing
  # and return float32 otherwise. The map_fn output signature and the
  # empty-batch result must match what `resize_op` actually returns.
  # Hard-coding float32 breaks e.g. uint8 images with NEAREST_NEIGHBOR.
  # `method` may be a `ResizeMethod` (str) or `ResizeMethodV1` (int) value.
  method = kwargs.get("method", image_ops.ResizeMethod.BILINEAR)
  if method in (image_ops.ResizeMethod.NEAREST_NEIGHBOR,
                image_ops.ResizeMethodV1.NEAREST_NEIGHBOR):
    out_dtype = images.dtype
  else:
    out_dtype = dtypes.float32

  # Determine the output shape (excluding the batch dimension).
  static_batch_size = tensor_shape.dimension_value(images.shape[0])
  size = ops.convert_to_tensor(size, dtypes.int32, "size")
  size_as_shape = tensor_util.constant_value_as_shape(size).with_rank(2)
  out_shape = size_as_shape + images.shape[-1:]
  out_spec = tensor_spec.TensorSpec(out_shape, out_dtype)

  def resize_one(image):
    # A batch element may still be ragged; the dense op needs a dense Tensor.
    if isinstance(image, ragged_tensor.RaggedTensor):
      image = image.to_tensor()
    return resize_op(image, size, **kwargs)

  def resize_with_map():
    return map_fn.map_fn_v2(resize_one, images, fn_output_signature=out_spec)

  def empty_result():
    # Zero-image batch: build the [0, new_height, new_width, channels] result
    # directly instead of mapping.
    # NOTE(review): assumes the channel count is the last dimension of
    # `flat_values` (i.e. the channel axis is uniform). TODO confirm for
    # inputs whose innermost dimension is ragged.
    channels = array_ops.shape(images.flat_values)[-1:]
    return array_ops.zeros(
        array_ops.concat([[0], size, channels], axis=0), dtype=out_dtype)

  if static_batch_size == 0:
    return empty_result()
  elif static_batch_size is not None:
    return resize_with_map()
  else:
    # Batch size is unknown statically, so pick the branch at run time. Both
    # branches share `out_dtype`, so their outputs agree.
    empty_batch = math_ops.equal(images.nrows(), 0)
    return control_flow_ops.cond(empty_batch, empty_result, resize_with_map)
99