# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Array operations for RaggedTensors."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sort_ops
from tensorflow.python.ops.ragged import ragged_functional_ops
from tensorflow.python.ops.ragged import ragged_math_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.ops.ragged import ragged_util
from tensorflow.python.ops.ragged import segment_id_ops
from tensorflow.python.util.tf_export import tf_export


#===============================================================================
# Masking
#===============================================================================


@tf_export('ragged.boolean_mask')
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; or if `mask.shape` is
      not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # lengths back to a splits tensor.
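      # For example (illustrative): with data = [[1, 2, 3], [4], [5, 6]] and
      # mask = [True, False, True], lengths = [3, 1, 2],
      # masked_lengths = [3, 2], and masked_splits = [0, 3, 5].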
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(masked_values,
                                                        masked_splits,
                                                        validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask, ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
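        # For example (illustrative): with data = [[1, 2, 3], [4, 5, 6]] and
        # mask = [[True, False, True], [False, True, False]], masked_values
        # is [1, 3, 5] and flattened_masked_lengths is [2, 1], yielding
        # [[1, 3], [5]].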
        masked_lengths = math_ops.count_nonzero(mask, axis=-1,
                                                dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          split_size = math_ops.cumprod(mask_shape) + 1
          for dim in range(mask.shape.ndims - 3, -1, -1):
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values


#===============================================================================
# Tiling
#===============================================================================
def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
  """Constructs a `RaggedTensor` by tiling a given `RaggedTensor`.

  The values of `input` are replicated `multiples[i]` times along the
  `i`th dimension (for each dimension `i`).  For every dimension `axis` in
  `input`, the length of each output element in that dimension is the
  length of the corresponding input element multiplied by `multiples[axis]`.

  Args:
    input: A `RaggedTensor`.
    multiples: A 1-D integer `Tensor`.  Length must be the same as the number of
      dimensions in `input`.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> tf.tile(rt, [3, 2]).to_list()
  [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
  """
  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')
    if not ragged_tensor.is_ragged(input):
      return array_ops.tile(input, multiples, name)
    multiples = ragged_util.convert_to_int_tensor(
        multiples, name='multiples', dtype=input.row_splits.dtype)
    multiples.shape.assert_has_rank(1)

    # If the constant value of `multiples` is available, then we can use it
    # to skip tiling dimensions where `multiples=1`.
    const_multiples = tensor_util.constant_value(multiples)

    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        _tile_ragged_values(input, multiples, const_multiples),
        _tile_ragged_splits(input, multiples, const_multiples),
        validate=False)


def _tile_ragged_values(rt_input, multiples, const_multiples=None):
  """Builds flat_values tensor for a tiled `RaggedTensor`.

  Returns a tensor that repeats the values in `rt_input.flat_values` in the
  appropriate pattern to construct a `RaggedTensor` that tiles `rt_input` as
  specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` whose values should be repeated.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A `Tensor` with the same type and rank as `rt_input.flat_values`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy()
  array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32)
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # Pointers to the values in `rt_input.flat_values`.
  inner_value_ids = math_ops.range(nested_splits[-1][-1])

  # For each ragged dimension (working from the innermost to outermost),
  # expand `inner_value_ids` as necessary to tile that dimension.
  prev_splits = None
  for axis in range(ragged_rank, 0, -1):
    # Ragged splits for this dimension.
    splits = nested_splits[axis - 1]

    # Adjust splits so they point into `inner_value_ids` (instead of just
    # pointing into the next dimension's values).
    if prev_splits is not None:  # Not the first pass through the loop.
      splits = array_ops.gather(prev_splits * multiples[axis + 1], splits)

    # Repeat each element in this ragged dimension `multiples[axis]` times.
    if const_multiples is None or const_multiples[axis] != 1:
      inner_value_ids = ragged_util.repeat_ranges(inner_value_ids, splits,
                                                  multiples[axis])

    prev_splits = splits

  # Gather the tiled inner values.
  ragged_tiled_values = array_ops.gather(rt_input.flat_values, inner_value_ids)

  # Tile the flat_values for the uniform dimensions (i.e., for `axis=0` plus
  # `axis=range(ragged_rank, rank)`).
  inner_repeats = array_ops.concat([multiples[:1], multiples[ragged_rank + 1:]],
                                   axis=0)
  return array_ops.tile(ragged_tiled_values, inner_repeats)


def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
  """Builds nested_split tensors for a tiled `RaggedTensor`.

  Returns a list of split tensors that can be used to construct the
  `RaggedTensor` that tiles `rt_input` as specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` that is being tiled.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A list of 1-D integer `Tensor`s (one for each ragged dimension in
    `rt_input`).

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_splits(rt, [3, 2])
  [<tf.Tensor: shape=(7,), dtype=int64,
  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
  """
  ragged_rank = rt_input.ragged_rank
  nested_splits = rt_input.nested_row_splits

  # projected_splits[src_axis, dst_axis] contains the split points that divide
  # the rows from src_axis in the list of dst_axis values.  E.g.,
  # projected_splits[i, i] = nested_splits[i], and
  # projected_splits[i, i+1] = gather(nested_splits[i+1], nested_splits[i]).
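  # For example (illustrative of the formula above): if
  # nested_splits[i] = [0, 2, 3] and nested_splits[i+1] = [0, 2, 4, 5], then
  # projected_splits[i][i+1] = gather([0, 2, 4, 5], [0, 2, 3]) = [0, 4, 5].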
  projected_splits = [{i: nested_splits[i]} for i in range(ragged_rank)]
  for src_axis in range(ragged_rank):
    for dst_axis in range(src_axis + 1, ragged_rank - 1):
      projected_splits[src_axis][dst_axis] = array_ops.gather(
          nested_splits[dst_axis],
          projected_splits[src_axis][dst_axis - 1])

  # For each ragged dimension: nested_splits[axis] -> result_splits[axis].
  result_splits = []
  for axis in range(ragged_rank):
    # Get the length of each row for the input tensor for this dimension.
    input_lengths = nested_splits[axis][1:] - nested_splits[axis][:-1]

    # Multiply those lengths by the `multiples` of dimension axis+1, since
    # each value will be repeated that number of times.
    output_lengths = input_lengths * multiples[axis + 1]

    # Repeat ranges of the row lengths as necessary for them to be tiled in
    # each ragged dimension `d < axis`.  (Start with dimension d=axis-1, and
    # work our way up to dimension d=0.)
    repeats = 1
    for d in range(axis - 1, -1, -1):
      if const_multiples is None or const_multiples[d + 1] != 1:
        splits = projected_splits[d][axis - 1] * repeats
        output_lengths = ragged_util.repeat_ranges(output_lengths, splits,
                                                   multiples[d + 1])
      repeats *= multiples[d + 1]

    # Tile splits for the outermost (uniform) dimension.
    output_lengths = array_ops.tile(output_lengths, multiples[:1])

    # Convert to splits.
    result_splits.append(ragged_util.lengths_to_splits(output_lengths))

  return result_splits


#===============================================================================
# Reshaping
#===============================================================================


def expand_dims(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.

  Given a potentially ragged tensor `input`, this operation inserts a
  dimension with size 1 at the dimension `axis` of `input`'s shape.

  * If `input` is a `Tensor`, then this is equivalent to
    `tf.expand_dims`.
  * If `input` is ragged, and `axis=0`, then the new dimension will be
    uniform; but the previously outermost dimension will become ragged.
  * If `input` is ragged, and `0 < axis < input.ragged_rank`, then the
    new dimension will be ragged.
  * If `input` is ragged, and `axis >= input.ragged_rank`, then the new
    dimension will be uniform.

  The following table gives some examples showing how `ragged.expand_dims`
  impacts the shapes of different input tensors.  Ragged dimensions are
  indicated by enclosing them in parentheses.

  input.shape             | axis | result.shape
  ----------------------- | ---- | -----------------------------
  `[D1, D2]`              |  `0` | `[1, D1, D2]`
  `[D1, D2]`              |  `1` | `[D1, 1, D2]`
  `[D1, D2]`              |  `2` | `[D1, D2, 1]`
  `[D1, (D2), (D3), D4]`  |  `0` | `[1, (D1), (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `1` | `[D1, (1), (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `2` | `[D1, (D2), (1), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `3` | `[D1, (D2), (D3), 1, D4]`
  `[D1, (D2), (D3), D4]`  |  `4` | `[D1, (D2), (D3), D4, 1]`

  Args:
    input: The potentially ragged tensor that should be expanded with a new
      dimension.
    axis: An integer constant indicating where the new dimension should be
      inserted.
    name: A name for the operation (optional).

  Returns:
    A tensor with the same values as `input`, with an added dimension of
    size 1 at `axis`.

  #### Examples:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> print(rt.shape)
  (2, None)

  >>> expanded = tf.expand_dims(rt, axis=0)
  >>> print(expanded.shape, expanded)
  (1, None, None) <tf.RaggedTensor [[[1, 2], [3]]]>

  >>> expanded = tf.expand_dims(rt, axis=1)
  >>> print(expanded.shape, expanded)
  (2, None, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>

  >>> expanded = tf.expand_dims(rt, axis=2)
  >>> print(expanded.shape, expanded)
  (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]>
  """
  with ops.name_scope(name, 'RaggedExpandDims', [input]):
    input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')

    if not ragged_tensor.is_ragged(input):
      return array_ops.expand_dims(input, axis)

    ndims = None if input.shape.ndims is None else input.shape.ndims + 1
    axis = ragged_util.get_positive_axis(axis, ndims)
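
    # For axis=0, the result is a single row containing all of `input`
    # (splits=[0, nrows]).  For axis=1, each existing row becomes its own
    # length-1 row (splits=[0, 1, ..., nrows]).  For axis>1, recurse into
    # `input.values` and reuse the existing row_splits.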
    if axis == 0:
      values = input
      splits = array_ops.stack([0, input.nrows()])
    elif axis == 1:
      values = input
      splits = math_ops.range(input.nrows() + 1)
    else:
      values = expand_dims(input.values, axis - 1)
      splits = input.row_splits

    return ragged_tensor.RaggedTensor.from_row_splits(values, splits,
                                                      validate=False)


#===============================================================================
# RaggedTensor Size
#===============================================================================


def size(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a potentially ragged tensor.

  The size of a ragged tensor is the size of its inner values.

  #### Example:

  >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy()
  3

  Args:
    input: A potentially ragged `Tensor`.
    out_type: The numeric output type for the operation.
    name: A name for the operation (optional).

  Returns:
    A Tensor of type `out_type`.
  """
  if ragged_tensor.is_ragged(input):
    return array_ops.size(input.flat_values, out_type=out_type, name=name)
  else:
    return array_ops.size(input, out_type=out_type, name=name)


#===============================================================================
# ragged.rank
#===============================================================================
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.

  #### Example:

  >>> # shape of tensor 't' is [2, None, None]
  >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  >>> tf.rank(t).numpy()
  3

  Args:
    input: A `RaggedTensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as name:
    if not ragged_tensor.is_ragged(input):
      return array_ops.rank(input, name)

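    # The overall rank is the number of ragged dimensions plus the rank of
    # the underlying `flat_values` tensor (which holds the uniform inner
    # dimensions).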
    return input.ragged_rank + array_ops.rank(input.flat_values)


#===============================================================================
# ragged.one_hot
#===============================================================================
def ragged_one_hot(indices,
                   depth,
                   on_value=None,
                   off_value=None,
                   axis=None,
                   dtype=None,
                   name=None):
  """Applies tf.one_hot along the values of a RaggedTensor."""
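  # For example (illustrative, eager mode): with
  # indices = tf.ragged.constant([[1, 2], [3]]) and depth=4, the result's
  # flat_values are one_hot([1, 2, 3], 4), so the result is
  # [[[0, 1, 0, 0], [0, 0, 1, 0]], [[0, 0, 0, 1]]] (float32 by default).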
  with ops.name_scope(name, 'RaggedOneHot', [indices]):
    indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    if axis is not None:
      axis = ragged_util.get_positive_axis(axis, indices.shape.ndims)
      if axis < indices.ragged_rank:
        raise ValueError('axis may not be less than indices.ragged_rank.')
    return indices.with_flat_values(
        array_ops.one_hot(indices.flat_values, depth, on_value, off_value, axis,
                          dtype, name))


#===============================================================================
# ragged.stack_dynamic_partitions
#===============================================================================
@tf_export('ragged.stack_dynamic_partitions')
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
  """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`.  Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.

  #### Example:

  >>> data           = ['a', 'b', 'c', 'd', 'e']
  >>> partitions     = [  3,   0,   2,   2,   3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
      partition that each slice of `data` should be added to.
      `partitions.shape` must be a prefix of `data.shape`.  Values must be
      greater than or equal to zero, and less than `num_partitions`.
      `partitions` is not required to be sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output.  This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.
  """
  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    row_splits_dtype = (
        data.row_splits.dtype
        if isinstance(data, ragged_tensor.RaggedTensor) else None)
    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        partitions, name='partitions', preferred_dtype=row_splits_dtype)
    num_partitions = ops.convert_to_tensor(
        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
    if row_splits_dtype is not None:
      partitions = math_ops.cast(partitions, row_splits_dtype)
    num_partitions = math_ops.cast(num_partitions, partitions.dtype)

    # Sanity-checks for shapes.
    partitions_rank = partitions.shape.ndims
    if partitions_rank is None:
      raise ValueError('partitions must have known rank.')
    num_partitions.shape.assert_has_rank(0)
    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])

    if partitions_rank == 0:
      # If partitions is a scalar, then just create a RaggedTensor containing
      # the complete `data` value in the specified row.
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=array_ops.stack([data]),
          value_rowids=array_ops.stack([partitions]),
          nrows=num_partitions,
          validate=False)

    elif partitions_rank == 1:
      # If partitions is a vector (the typical case): we can just use data and
      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
      # as long as we sort them first.
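      # For example (illustrative): with data = ['a', 'b', 'c', 'd', 'e'] and
      # partitions = [3, 0, 2, 2, 3], the stable argsort gives
      # permutation = [1, 2, 3, 0, 4], so values = ['b', 'c', 'd', 'a', 'e']
      # and value_rowids = [0, 2, 2, 3, 3].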
      permutation = sort_ops.argsort(partitions, stable=True)
      value_rowids = array_ops.gather(partitions, permutation)
      values = array_ops.gather(data, permutation)
      check = check_ops.assert_less(
          value_rowids[-1:],
          num_partitions,
          message='partitions must be less than num_partitions')
      with ops.control_dependencies([check]):
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values, value_rowids, nrows=num_partitions, validate=False)

    else:
      # Handle higher-dimensional partitions via recursion.
      if not isinstance(data, ragged_tensor.RaggedTensor):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
      if not isinstance(partitions, ragged_tensor.RaggedTensor):
        partitions = ragged_tensor.RaggedTensor.from_tensor(
            partitions,
            row_splits_dtype=partitions.dtype,
            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
      check = check_ops.assert_equal(
          data.row_splits,
          partitions.row_splits,
          message='data and partitions have incompatible ragged shapes')
      with ops.control_dependencies([check]):
        return stack_dynamic_partitions(data.values, partitions.values,
                                        num_partitions)


#===============================================================================
# Reverse
#===============================================================================
def reverse(tensor, axis, name=None):
  """Reverses a RaggedTensor along the specified axes.

  #### Example:

  >>> data = tf.ragged.constant([
  ...   [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
  >>> tf.reverse(data, axis=[0, 2])
  <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>

  Args:
    tensor: A `RaggedTensor` to reverse.
    axis: A list or tuple of `int` or a constant 1-D `tf.Tensor`. The indices
      of the axes to reverse.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor`.
  """
  type_error_msg = ('`axis` must be a list of int or a constant tensor '
                    'when reversing axes in a ragged tensor')

  with ops.name_scope(name, 'Reverse', [tensor, axis]):
    if isinstance(axis, ops.Tensor):
      axis = tensor_util.constant_value(axis)
      if axis is None:
        raise TypeError(type_error_msg)
    elif not (isinstance(axis, (list, tuple)) and
              all(isinstance(dim, int) for dim in axis)):
      raise TypeError(type_error_msg)

    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        tensor, name='tensor')

    # Allow usage of negative values to specify innermost axes.
    axis = [ragged_util.get_positive_axis(dim, tensor.shape.rank)
            for dim in axis]

    # We only need to slice up to the max axis. If the axis list
    # is empty, it should be 0.
    slices = [slice(None)] * (max(axis) + 1 if axis else 0)

    for dim in axis:
      slices[dim] = slice(None, None, -1)

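    # For example (illustrative): axis = [0, 2] yields
    # slices == [slice(None, None, -1), slice(None), slice(None, None, -1)],
    # i.e. the equivalent of tensor[::-1, :, ::-1].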
    return tensor[tuple(slices)]