# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Concat and stack operations for RaggedTensors."""

import typing

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.ragged import ragged_gather_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export


@dispatch.dispatch_for_api(array_ops.concat)
def concat(values: typing.List[ragged_tensor.RaggedOrDense], axis, name=None):
  """Concatenates potentially ragged tensors along one dimension.

  Given a list of tensors with the same rank `K` (`K >= axis`), returns a
  rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
  concatenation of `[rt[i0...iaxis] for rt in values]`.

  Args:
    values: A list of potentially ragged tensors. May not be empty. All
      `values` must have the same rank and the same dtype; but unlike
      `tf.concat`, they can have arbitrary shapes.
    axis: A python integer, indicating the dimension along which to
      concatenate.  (Note: Unlike `tf.concat`, the `axis` parameter must be
      statically known.)  Negative values are supported only if the rank of
      at least one `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `K`.
    `result.ragged_rank = max(axis, max(rt.ragged_rank for rt in values))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.

  #### Example:

  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.concat([t1, t2], axis=0)
  <tf.RaggedTensor [[1, 2], [3, 4, 5], [6], [7, 8, 9]]>
  >>> tf.concat([t1, t2], axis=1)
  <tf.RaggedTensor [[1, 2, 6], [3, 4, 5, 7, 8, 9]]>
  """
  if not isinstance(values, (list, tuple)):
    values = [values]
  with ops.name_scope(name, 'RaggedConcat', values):
    return _ragged_stack_concat_helper(values, axis, stack_values=False)


@tf_export('ragged.stack')
@dispatch.add_dispatch_support
@dispatch.dispatch_for_api(array_ops.stack)
def stack(values: typing.List[ragged_tensor.RaggedOrDense],
          axis=0,
          name=None):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.

  Given a list of tensors or ragged tensors with the same rank `R`
  (`R >= axis`), returns a rank-`R+1` `RaggedTensor` `result` such that
  `result[i0...iaxis]` is `[value[i0...iaxis] for value in values]`.

  #### Examples:

  >>> # Stacking two ragged tensors.
  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.ragged.stack([t1, t2], axis=0)
  <tf.RaggedTensor [[[1, 2], [3, 4, 5]], [[6], [7, 8, 9]]]>
  >>> tf.ragged.stack([t1, t2], axis=1)
  <tf.RaggedTensor [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]>

  >>> # Stacking two dense tensors with different sizes.
  >>> t3 = tf.constant([[1, 2, 3], [4, 5, 6]])
  >>> t4 = tf.constant([[5], [6], [7]])
  >>> tf.ragged.stack([t3, t4], axis=0)
  <tf.RaggedTensor [[[1, 2, 3], [4, 5, 6]], [[5], [6], [7]]]>

  Args:
    values: A list of `tf.Tensor` or `tf.RaggedTensor`. May not be empty. All
      `values` must have the same rank and the same dtype; but unlike
      `tf.stack`, they can have arbitrary dimension sizes.
    axis: A python integer, indicating the dimension along which to stack.
      (Note: Unlike `tf.stack`, the `axis` parameter must be statically
      known.)  Negative values are supported only if the rank of at least one
      `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `R+1` (if `R>0`).
    If `R==0`, then the result will be returned as a 1D `Tensor`, since
    `RaggedTensor` can only be used when `rank>1`.
    `result.ragged_rank = 1 + max(axis, max(rt.ragged_rank for rt in
    values))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.
  """
  if not isinstance(values, (list, tuple)):
    values = [values]
  # NOTE(review): the scope name 'RaggedConcat' (not 'RaggedStack') is kept
  # as-is: changing it would rename graph ops that callers may depend on.
  with ops.name_scope(name, 'RaggedConcat', values):
    return _ragged_stack_concat_helper(values, axis, stack_values=True)


def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  # Validate parameters.
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert input tensors.
  rt_inputs = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
      *rt_inputs, return_dtype=True)
  rt_inputs = list(rt_inputs)

  # Special case: if there's only one input, then return it as-is.
  if len(rt_inputs) == 1 and not stack_values:
    return rt_inputs[0]

  # Check the rank (number of dimensions) of the input tensors.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
  axis = array_ops.get_positive_axis(axis, out_ndims)

  if stack_values and ndims == 1 and axis == 0:
    return ragged_tensor.RaggedTensor.from_row_lengths(
        values=array_ops.concat(rt_inputs, axis=0),
        row_lengths=array_ops.concat([array_ops.shape(r) for r in rt_inputs],
                                     axis=0))

  # If all the inputs are Tensors, and we're combining the final dimension,
  # then we can delegate to the tf.stack/tf.concat operation, and return a
  # Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops.stack(rt_inputs, axis)
      else:
        return array_ops.concat(rt_inputs, axis)

  # Convert any Tensor inputs to RaggedTensors.  This makes it
  # possible to concatenate Tensors and RaggedTensors together.
  for i in range(len(rt_inputs)):
    if not ragged_tensor.is_ragged(rt_inputs[i]):
      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

  # Convert the input tensors to all have the same ragged_rank.
  ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
               for rt in rt_inputs]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  elif axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  else:  # axis > 1: recurse.
    values = [rt.values for rt in rt_inputs]
    splits = [[rt_input.row_splits] for rt_input in rt_inputs]
    with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
      return ragged_tensor.RaggedTensor.from_row_splits(
          _ragged_stack_concat_helper(values, axis - 1, stack_values),
          splits[0][0], validate=False)


def _ragged_stack_concat_axis_0(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 0.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and
      ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  """
  # Concatenate the inner values together.
  flat_values = [rt.flat_values for rt in rt_inputs]
  concatenated_flat_values = array_ops.concat(flat_values, axis=0)

  # Concatenate the splits together for each ragged dimension (adjusting
  # split offsets as necessary).
  nested_splits = [rt.nested_row_splits for rt in rt_inputs]
  ragged_rank = rt_inputs[0].ragged_rank
  concatenated_nested_splits = [
      _concat_ragged_splits([ns[dim]
                             for ns in nested_splits])
      for dim in range(ragged_rank)
  ]

  # If we are performing a stack operation, then add another splits.
  if stack_values:
    stack_lengths = array_ops.stack([rt.nrows() for rt in rt_inputs])
    stack_splits = ragged_util.lengths_to_splits(stack_lengths)
    concatenated_nested_splits.insert(0, stack_splits)

  return ragged_tensor.RaggedTensor.from_nested_row_splits(
      concatenated_flat_values, concatenated_nested_splits, validate=False)


def _ragged_stack_concat_axis_1(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 1.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and
      ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  """
  num_inputs = len(rt_inputs)

  rt_nrows = rt_inputs[0].nrows()
  nrows_msg = 'Input tensors have incompatible shapes.'
  nrows_checks = [
      check_ops.assert_equal(rt.nrows(), rt_nrows, message=nrows_msg)
      for rt in rt_inputs[1:]
  ]

  with ops.control_dependencies(nrows_checks):
    # Concatenate the inputs together to put them in a single ragged tensor.
    concatenated_rt = _ragged_stack_concat_axis_0(rt_inputs,
                                                  stack_values=False)

    # Use ragged.gather to permute the rows of concatenated_rt.  In
    # particular,
    #   permuted_rt = [rt_inputs[0][0], ..., rt_inputs[N][0],
    #                  rt_inputs[0][1], ..., rt_inputs[N][1],
    #                  ...,
    #                  rt_inputs[0][M], ..., rt_input[N][M]]
    # where `N=num_inputs-1` and `M=rt_nrows-1`.
    row_indices = math_ops.range(rt_nrows * num_inputs)
    row_index_matrix = array_ops.reshape(row_indices, [num_inputs, -1])
    transposed_row_index_matrix = array_ops.transpose(row_index_matrix)
    row_permutation = array_ops.reshape(transposed_row_index_matrix, [-1])
    permuted_rt = ragged_gather_ops.gather(concatenated_rt, row_permutation)

    if stack_values:
      # Add a new splits tensor to group together the values.
      stack_splits = math_ops.range(0, rt_nrows * num_inputs + 1, num_inputs)
      _copy_row_shape(rt_inputs, stack_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt, stack_splits, validate=False)
    else:
      # Merge together adjacent rows by dropping the row-split indices that
      # separate them.
      concat_splits = permuted_rt.row_splits[::num_inputs]
      _copy_row_shape(rt_inputs, concat_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt.values, concat_splits, validate=False)


def _copy_row_shape(rt_inputs, splits):
  """Sets splits.shape to [rt.shape[0] + 1] for each rt in rt_inputs."""
  for rt in rt_inputs:
    if rt.shape[0] is not None:
      splits.set_shape(tensor_shape.TensorShape(rt.shape[0] + 1))


def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
  if ragged_rank > 0:
    if not ragged_tensor.is_ragged(rt_input):
      rt_input = ragged_tensor.RaggedTensor.from_tensor(
          rt_input, row_splits_dtype=row_splits_dtype)
    if rt_input.ragged_rank < ragged_rank:
      rt_input = rt_input.with_values(
          _increase_ragged_rank_to(rt_input.values, ragged_rank - 1,
                                   row_splits_dtype))
  return rt_input


def _concat_ragged_splits(splits_list):
  """Concatenates a list of RaggedTensor splits to form a single splits."""
  pieces = [splits_list[0]]
  splits_offset = splits_list[0][-1]
  for splits in splits_list[1:]:
    # Shift each subsequent splits vector by the running total of values
    # already consumed, and drop its leading 0 so the result stays monotonic.
    pieces.append(splits[1:] + splits_offset)
    splits_offset += splits[-1]
  return array_ops.concat(pieces, axis=0)