# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Image operations for RaggedTensors."""

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import map_fn
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.util import dispatch


@dispatch.dispatch_for_api(image_ops.resize_images_v2)
def resize_images_v2(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethod.BILINEAR,
                     preserve_aspect_ratio=False,
                     antialias=False,
                     name=None):
  """RaggedTensor dispatcher for tf.image.resize (tf-v2).

  Each image in the ragged batch is resized individually with
  `image_ops.resize_images_v2`, and the results are stacked into a dense
  tensor.

  Args:
    images: A rank-4 `RaggedTensor` of images.
    size: The 2-element target size, forwarded to
      `image_ops.resize_images_v2`.
    method: The resize method, forwarded to `image_ops.resize_images_v2`.
      It also determines the output dtype (see `_resize_images`).
    preserve_aspect_ratio: Forwarded to `image_ops.resize_images_v2`.
    antialias: Forwarded to `image_ops.resize_images_v2`.
    name: Optional name scope. Defaults to "RaggedResizeImages".

  Returns:
    A dense `Tensor` holding the resized images stacked along the batch
    dimension.

  Raises:
    ValueError: If `images` does not have rank 4.
  """
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    return _resize_images(
        image_ops.resize_images_v2,
        images,
        size,
        method=method,
        preserve_aspect_ratio=preserve_aspect_ratio,
        antialias=antialias)


@dispatch.dispatch_for_api(image_ops.resize_images)
def resize_images_v1(images: ragged_tensor.RaggedTensor,
                     size,
                     method=image_ops.ResizeMethodV1.BILINEAR,
                     align_corners=False,
                     preserve_aspect_ratio=False,
                     name=None):
  """RaggedTensor dispatcher for tf.image.resize (tf-v1).

  Each image in the ragged batch is resized individually with
  `image_ops.resize_images`, and the results are stacked into a dense tensor.

  Args:
    images: A rank-4 `RaggedTensor` of images.
    size: The 2-element target size, forwarded to `image_ops.resize_images`.
    method: The resize method, forwarded to `image_ops.resize_images`. It
      also determines the output dtype (see `_resize_images`).
    align_corners: Forwarded to `image_ops.resize_images`.
    preserve_aspect_ratio: Forwarded to `image_ops.resize_images`.
    name: Optional name scope. Defaults to "RaggedResizeImages".

  Returns:
    A dense `Tensor` holding the resized images stacked along the batch
    dimension.

  Raises:
    ValueError: If `images` does not have rank 4.
  """
  with ops.name_scope(name, "RaggedResizeImages", [images, size]):
    return _resize_images(
        image_ops.resize_images,
        images,
        size,
        method=method,
        preserve_aspect_ratio=preserve_aspect_ratio,
        align_corners=align_corners)


def _resize_output_dtype(method, images_dtype):
  """Returns the dtype the dense resize op produces for `method`.

  Per the `tf.image.resize` and `tf.compat.v1.image.resize` documentation,
  nearest-neighbor resizing keeps the dtype of the input images, and the
  other methods produce float32. Both the V2 string constant and the V1
  integer constant are recognised, because `image_ops.resize_images`
  accepts either one.

  Args:
    method: The resize method passed to the dense op, or None.
    images_dtype: The dtype of the input images.

  Returns:
    `images_dtype` for nearest-neighbor resizing; otherwise `dtypes.float32`.
  """
  if (method == image_ops.ResizeMethod.NEAREST_NEIGHBOR or
      method == image_ops.ResizeMethodV1.NEAREST_NEIGHBOR):
    return images_dtype
  return dtypes.float32


def _resize_images(resize_op, images, size, **kwargs):
  """RaggedTensor dispatcher for tf.image.resize.

  Resizes each image of a ragged batch with `resize_op` and stacks the
  results into a dense tensor.

  Args:
    resize_op: The dense resize function (`image_ops.resize_images_v2` or
      `image_ops.resize_images`). It is called once per image as
      `resize_op(image, size, **kwargs)`.
    images: A rank-4 `RaggedTensor`.
    size: The target size. It is converted to an int32 tensor whose static
      shape must be compatible with exactly 2 elements.
    **kwargs: Extra keyword arguments forwarded to `resize_op`. The `method`
      entry, if present, also selects the output dtype.

  Returns:
    A dense `Tensor` whose first dimension is the number of rows of `images`.
    Its dtype is `images.dtype` for nearest-neighbor resizing and float32
    otherwise.

  Raises:
    ValueError: If `images.shape.rank` is not 4.
  """
  if images.shape.rank != 4:
    raise ValueError(
        "tf.image.resize: images.shape.rank must be 4 if images is ragged.")

  # Determine the output shape (excluding the batch dimension).
  static_batch_size = tensor_shape.dimension_value(images.shape[0])
  size = ops.convert_to_tensor(size, dtypes.int32, "size")
  size_as_shape = tensor_util.constant_value_as_shape(size).with_rank(2)
  out_shape = size_as_shape + images.shape[-1:]

  # The output dtype must match what `resize_op` actually returns.
  # Hard-coding float32 broke nearest-neighbor resizing of non-float32
  # images: the map_fn output signature, and the two `cond` branches below,
  # disagreed on dtype.
  # NOTE(review): with preserve_aspect_ratio=True the per-image output may be
  # smaller than `size`. Presumably the fixed `out_shape` signature rejects
  # that, so confirm. V1 may also return an image unchanged (in its input
  # dtype) when its static size already equals `size`; that case is not
  # modelled here.
  out_dtype = _resize_output_dtype(kwargs.get("method"), images.dtype)
  out_spec = tensor_spec.TensorSpec(out_shape, out_dtype)

  def resize_one(image):
    # map_fn may hand over each element as a RaggedTensor. Densify it before
    # calling the dense resize op.
    if isinstance(image, ragged_tensor.RaggedTensor):
      image = image.to_tensor()
    return resize_op(image, size, **kwargs)

  def resize_with_map():
    return map_fn.map_fn_v2(resize_one, images, fn_output_signature=out_spec)

  def empty_result():
    # Empty batches bypass map_fn. Build an all-zero result of shape
    # [0, size[0], size[1], last dim of flat_values], in the same dtype as
    # the map_fn path so that `cond` branches agree.
    channels = array_ops.shape(images.flat_values)[-1:]
    return array_ops.zeros(
        array_ops.concat([[0], size, channels], axis=0), dtype=out_dtype)

  if static_batch_size == 0:
    return empty_result()
  if static_batch_size is not None:
    return resize_with_map()
  # The batch size is only known at run time, so choose the branch in-graph.
  empty_batch = math_ops.equal(images.nrows(), 0)
  return control_flow_ops.cond(empty_batch, empty_result, resize_with_map)