• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Concat and stack operations for RaggedTensors."""
16
17import typing
18
19from tensorflow.python.framework import ops
20from tensorflow.python.framework import tensor_shape
21from tensorflow.python.ops import array_ops
22from tensorflow.python.ops import check_ops
23from tensorflow.python.ops import math_ops
24from tensorflow.python.ops.ragged import ragged_gather_ops
25from tensorflow.python.ops.ragged import ragged_tensor
26from tensorflow.python.ops.ragged import ragged_util
27from tensorflow.python.util import dispatch
28from tensorflow.python.util.tf_export import tf_export
29
30
@dispatch.dispatch_for_api(array_ops.concat)
def concat(values: typing.List[ragged_tensor.RaggedOrDense], axis, name=None):
  """Concatenates potentially ragged tensors along one dimension.

  Given a list of tensors with the same rank `K` (`K >= axis`), returns a
  rank-`K` `RaggedTensor` `result` such that `result[i0...iaxis]` is the
  concatenation of `[rt[i0...iaxis] for rt in values]`.

  Args:
    values: A list of potentially ragged tensors.  May not be empty. All
      `values` must have the same rank and the same dtype; but unlike
      `tf.concat`, they can have arbitrary shapes.
    axis: A python integer, indicating the dimension along which to concatenate.
      (Note: Unlike `tf.concat`, the `axis` parameter must be statically known.)
        Negative values are supported only if the rank of at least one
        `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `K`.
    `result.ragged_rank=max(axis, max(rt.ragged_rank for rt in values))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.

  #### Example:

  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.concat([t1, t2], axis=0)
  <tf.RaggedTensor [[1, 2], [3, 4, 5], [6], [7, 8, 9]]>
  >>> tf.concat([t1, t2], axis=1)
  <tf.RaggedTensor [[1, 2, 6], [3, 4, 5, 7, 8, 9]]>
  """
  # Accept a bare tensor for convenience, mirroring `tf.concat`'s behavior.
  if not isinstance(values, (list, tuple)):
    values = [values]
  with ops.name_scope(name, 'RaggedConcat', values):
    # stack_values=False: combine along an existing axis (no new dimension).
    return _ragged_stack_concat_helper(values, axis, stack_values=False)
70
71
@tf_export('ragged.stack')
@dispatch.add_dispatch_support
@dispatch.dispatch_for_api(array_ops.stack)
def stack(values: typing.List[ragged_tensor.RaggedOrDense],
          axis=0,
          name=None):
  """Stacks a list of rank-`R` tensors into one rank-`(R+1)` `RaggedTensor`.

  Given a list of tensors or ragged tensors with the same rank `R`
  (`R >= axis`), returns a rank-`R+1` `RaggedTensor` `result` such that
  `result[i0...iaxis]` is `[value[i0...iaxis] for value in values]`.

  #### Examples:

  >>> # Stacking two ragged tensors.
  >>> t1 = tf.ragged.constant([[1, 2], [3, 4, 5]])
  >>> t2 = tf.ragged.constant([[6], [7, 8, 9]])
  >>> tf.ragged.stack([t1, t2], axis=0)
  <tf.RaggedTensor [[[1, 2], [3, 4, 5]], [[6], [7, 8, 9]]]>
  >>> tf.ragged.stack([t1, t2], axis=1)
  <tf.RaggedTensor [[[1, 2], [6]], [[3, 4, 5], [7, 8, 9]]]>

  >>> # Stacking two dense tensors with different sizes.
  >>> t3 = tf.constant([[1, 2, 3], [4, 5, 6]])
  >>> t4 = tf.constant([[5], [6], [7]])
  >>> tf.ragged.stack([t3, t4], axis=0)
  <tf.RaggedTensor [[[1, 2, 3], [4, 5, 6]], [[5], [6], [7]]]>

  Args:
    values: A list of `tf.Tensor` or `tf.RaggedTensor`.  May not be empty. All
      `values` must have the same rank and the same dtype; but unlike
      `tf.stack`, they can have arbitrary dimension sizes.
    axis: A python integer, indicating the dimension along which to stack.
      (Note: Unlike `tf.stack`, the `axis` parameter must be statically known.)
      Negative values are supported only if the rank of at least one
      `values` value is statically known.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` with rank `R+1` (if `R>0`).
    If `R==0`, then the result will be returned as a 1D `Tensor`, since
    `RaggedTensor` can only be used when `rank>1`.
    `result.ragged_rank=1+max(axis, max(rt.ragged_rank for rt in values))`.

  Raises:
    ValueError: If `values` is empty, if `axis` is out of bounds or if
      the input tensors have different ranks.
  """
  # Accept a bare tensor for convenience, mirroring `tf.stack`'s behavior.
  if not isinstance(values, (list, tuple)):
    values = [values]
  # NOTE: scope name 'RaggedConcat' (not 'RaggedStack') is kept as-is; changing
  # it would alter op names in existing graphs.
  with ops.name_scope(name, 'RaggedConcat', values):
    # stack_values=True: add a new dimension at `axis`.
    return _ragged_stack_concat_helper(values, axis, stack_values=True)
124
125
def _ragged_stack_concat_helper(rt_inputs, axis, stack_values):
  """Helper function to concatenate or stack ragged tensors.

  Args:
    rt_inputs: A list of RaggedTensors or Tensors to combine.
    axis: The axis along which to concatenate or stack.
    stack_values: A boolean -- if true, then stack values; otherwise,
      concatenate them.

  Returns:
    A RaggedTensor.
  Raises:
    ValueError: If rt_inputs is empty, or if axis is out of range.
  """
  # Validate parameters.
  if not rt_inputs:
    raise ValueError('rt_inputs may not be empty.')

  # Convert input tensors.
  rt_inputs = [
      ragged_tensor.convert_to_tensor_or_ragged_tensor(
          rt_input, name='rt_input') for rt_input in rt_inputs
  ]
  # Coerce all inputs' row_splits to a single dtype so their splits can be
  # combined later; remember that dtype for promoting dense inputs below.
  row_splits_dtype, rt_inputs = ragged_tensor.match_row_splits_dtypes(
      *rt_inputs, return_dtype=True)
  # match_row_splits_dtypes may return a tuple; make it mutable for the
  # in-place Tensor->RaggedTensor conversion below.
  rt_inputs = list(rt_inputs)

  # Special case: if there's only one input, then return it as-is.
  if len(rt_inputs) == 1 and not stack_values:
    return rt_inputs[0]

  # Check the rank (number of dimensions) of the input tensors.
  # Note: inputs with statically-unknown rank (shape.ndims is None) are
  # skipped until a known rank is found; every input after that point is
  # asserted to have that rank.
  ndims = None
  for rt in rt_inputs:
    if ndims is None:
      ndims = rt.shape.ndims
    else:
      rt.shape.assert_has_rank(ndims)

  # Stacking adds one output dimension, so a negative `axis` must be
  # normalized against the OUTPUT rank, not the input rank.
  out_ndims = ndims if (ndims is None or not stack_values) else ndims + 1
  axis = array_ops.get_positive_axis(axis, out_ndims)

  # Special case: stacking vectors along axis 0 is just a ragged batch of the
  # concatenated values, with one row per input vector.
  if stack_values and ndims == 1 and axis == 0:
    return ragged_tensor.RaggedTensor.from_row_lengths(
        values=array_ops.concat(rt_inputs, axis=0),
        row_lengths=array_ops.concat([array_ops.shape(r) for r in rt_inputs],
                                     axis=0))

  # If all the inputs are Tensors, and we're combining the final dimension,
  # then we can delegate to the tf.stack/tf.concat operation, and return a
  # Tensor.
  if all(not ragged_tensor.is_ragged(rt) for rt in rt_inputs):
    if ndims is not None and (axis == out_ndims - 1 or axis == ndims - 1):
      if stack_values:
        return array_ops.stack(rt_inputs, axis)
      else:
        return array_ops.concat(rt_inputs, axis)

  # Convert any Tensor inputs to RaggedTensors.  This makes it
  # possible to concatenate Tensors and RaggedTensors together.
  for i in range(len(rt_inputs)):
    if not ragged_tensor.is_ragged(rt_inputs[i]):
      rt_inputs[i] = ragged_tensor.RaggedTensor.from_tensor(
          rt_inputs[i], ragged_rank=1, row_splits_dtype=row_splits_dtype)

  # Convert the input tensors to all have the same ragged_rank.
  ragged_rank = max(max(rt.ragged_rank for rt in rt_inputs), 1)
  rt_inputs = [_increase_ragged_rank_to(rt, ragged_rank, row_splits_dtype)
               for rt in rt_inputs]

  if axis == 0:
    return _ragged_stack_concat_axis_0(rt_inputs, stack_values)
  elif axis == 1:
    return _ragged_stack_concat_axis_1(rt_inputs, stack_values)
  else:  # axis > 1: recurse.
    # Strip one ragged dimension off every input, combine the values along
    # axis-1, and re-wrap with the (verified-identical) outer row_splits.
    values = [rt.values for rt in rt_inputs]
    splits = [[rt_input.row_splits] for rt_input in rt_inputs]
    with ops.control_dependencies(ragged_util.assert_splits_match(splits)):
      return ragged_tensor.RaggedTensor.from_row_splits(
          _ragged_stack_concat_helper(values, axis - 1, stack_values),
          splits[0][0], validate=False)
207
208
def _ragged_stack_concat_axis_0(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 0.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  # Join all inputs' innermost (flat) values into a single tensor.
  combined_flat_values = array_ops.concat(
      [rt.flat_values for rt in rt_inputs], axis=0)

  # For every ragged dimension, splice the inputs' row_splits together,
  # offsetting each input's splits by the running value count.
  splits_per_input = [rt.nested_row_splits for rt in rt_inputs]
  num_ragged_dims = rt_inputs[0].ragged_rank
  combined_nested_splits = []
  for dim in range(num_ragged_dims):
    combined_nested_splits.append(
        _concat_ragged_splits([splits[dim] for splits in splits_per_input]))

  # Stacking introduces one extra outer ragged dimension that groups the rows
  # contributed by each input.
  if stack_values:
    outer_lengths = array_ops.stack([rt.nrows() for rt in rt_inputs])
    outer_splits = ragged_util.lengths_to_splits(outer_lengths)
    combined_nested_splits.insert(0, outer_splits)

  return ragged_tensor.RaggedTensor.from_nested_row_splits(
      combined_flat_values, combined_nested_splits, validate=False)
242
243
def _ragged_stack_concat_axis_1(rt_inputs, stack_values):
  """Helper function to concatenate or stack ragged tensors along axis 1.

  Args:
    rt_inputs: A list of RaggedTensors, all with the same rank and ragged_rank.
    stack_values: Boolean.  If true, then stack values; otherwise, concatenate
      them.

  Returns:
    A RaggedTensor.
  """
  num_inputs = len(rt_inputs)

  # All inputs must have the same number of rows; assert at runtime.
  expected_nrows = rt_inputs[0].nrows()
  nrows_checks = [
      check_ops.assert_equal(
          rt.nrows(), expected_nrows,
          message='Input tensors have incompatible shapes.')
      for rt in rt_inputs[1:]
  ]

  with ops.control_dependencies(nrows_checks):
    # Concatenate the inputs together to put them in a single ragged tensor.
    combined_rt = _ragged_stack_concat_axis_0(rt_inputs, stack_values=False)

    # Use ragged.gather to permute the rows of combined_rt.  In particular,
    #   permuted_rt = [rt_inputs[0][0], ..., rt_inputs[N][0],
    #                  rt_inputs[0][1], ..., rt_inputs[N][1],
    #                      ...,
    #                  rt_inputs[0][M], ..., rt_input[N][M]]
    # where `N=num_inputs-1` and `M=expected_nrows-1`.
    flat_indices = math_ops.range(expected_nrows * num_inputs)
    index_matrix = array_ops.reshape(flat_indices, [num_inputs, -1])
    interleave_order = array_ops.reshape(
        array_ops.transpose(index_matrix), [-1])
    permuted_rt = ragged_gather_ops.gather(combined_rt, interleave_order)

    if stack_values:
      # Add a new splits tensor grouping every `num_inputs` adjacent rows.
      stack_splits = math_ops.range(
          0, expected_nrows * num_inputs + 1, num_inputs)
      _copy_row_shape(rt_inputs, stack_splits)
      return ragged_tensor.RaggedTensor.from_row_splits(
          permuted_rt, stack_splits, validate=False)

    # Concatenating: merge adjacent rows by dropping the row-split indices
    # that separate them.
    concat_splits = permuted_rt.row_splits[::num_inputs]
    _copy_row_shape(rt_inputs, concat_splits)
    return ragged_tensor.RaggedTensor.from_row_splits(
        permuted_rt.values, concat_splits, validate=False)
293
294
def _copy_row_shape(rt_inputs, splits):
  """Sets splits.shape to [rt.shape[0] + 1] for each rt in rt_inputs.

  A splits vector for a ragged tensor with N rows has N+1 entries, so any
  input with a statically-known row count lets us pin down splits' shape.

  Args:
    rt_inputs: RaggedTensors whose outer-dimension size may be statically
      known.
    splits: A row-splits Tensor whose static shape should be refined in place.
  """
  for rt in rt_inputs:
    if rt.shape[0] is not None:
      splits.set_shape(tensor_shape.TensorShape(rt.shape[0] + 1))
300
301
def _increase_ragged_rank_to(rt_input, ragged_rank, row_splits_dtype):
  """Adds ragged dimensions to `rt_input` so it has the desired ragged rank."""
  # Nothing to do when no ragged dimensions are requested.
  if ragged_rank <= 0:
    return rt_input
  # Promote a dense Tensor to a RaggedTensor first.
  if not ragged_tensor.is_ragged(rt_input):
    rt_input = ragged_tensor.RaggedTensor.from_tensor(
        rt_input, row_splits_dtype=row_splits_dtype)
  # Recurse into the values until the requested ragged rank is reached.
  if rt_input.ragged_rank < ragged_rank:
    deeper_values = _increase_ragged_rank_to(
        rt_input.values, ragged_rank - 1, row_splits_dtype)
    rt_input = rt_input.with_values(deeper_values)
  return rt_input
313
314
def _concat_ragged_splits(splits_list):
  """Concatenates a list of RaggedTensor splits to form a single splits."""
  first_splits = splits_list[0]
  pieces = [first_splits]
  # Each subsequent splits vector is shifted by the number of values already
  # accounted for (the last entry of the running splits), and its leading 0
  # is dropped so the result stays a valid monotone splits vector.
  offset = first_splits[-1]
  for next_splits in splits_list[1:]:
    pieces.append(next_splits[1:] + offset)
    offset += next_splits[-1]
  return array_ops.concat(pieces, axis=0)
323