• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Array operations for RaggedTensors."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21from tensorflow.python.framework import dtypes
22from tensorflow.python.framework import ops
23from tensorflow.python.framework import sparse_tensor
24from tensorflow.python.framework import tensor_util
25from tensorflow.python.ops import array_ops
26from tensorflow.python.ops import check_ops
27from tensorflow.python.ops import gen_ragged_array_ops
28from tensorflow.python.ops import math_ops
29from tensorflow.python.ops import sort_ops
30from tensorflow.python.ops.ragged import ragged_functional_ops
31from tensorflow.python.ops.ragged import ragged_math_ops
32from tensorflow.python.ops.ragged import ragged_tensor
33from tensorflow.python.ops.ragged import ragged_util
34from tensorflow.python.ops.ragged import segment_id_ops
35from tensorflow.python.util import dispatch
36from tensorflow.python.util.tf_export import tf_export
37
38#===============================================================================
39# Masking
40#===============================================================================
41
42
@tf_export('ragged.boolean_mask')
@dispatch.add_dispatch_support
def boolean_mask(data, mask, name=None):
  """Applies a boolean mask to `data` without flattening the mask dimensions.

  Returns a potentially ragged tensor that is formed by retaining the elements
  in `data` where the corresponding value in `mask` is `True`.

  * `output[a1...aA, i, b1...bB] = data[a1...aA, j, b1...bB]`

     Where `j` is the `i`th `True` entry of `mask[a1...aA]`.

  Note that `output` preserves the mask dimensions `a1...aA`; this differs
  from `tf.boolean_mask`, which flattens those dimensions.

  The implementation dispatches on the kinds of its inputs:

  * ragged `mask`: the ragged dimensions are stripped from both `mask` and
    `data`, the remaining (non-ragged) mask is applied recursively, and the
    ragged dimensions are added back to the result.
  * dense rank-1 `mask` with ragged `data`: whole rows of `data` are kept or
    dropped.
  * dense `mask` with rank > 1 and ragged `data`: `mask` is converted to a
    `RaggedTensor` and the function recurses.
  * dense `mask` and dense `data`: `tf.boolean_mask` is applied and the mask
    dimensions are rebuilt as ragged dimensions.

  Args:
    data: A potentially ragged tensor.
    mask: A potentially ragged boolean tensor.  `mask`'s shape must be a prefix
      of `data`'s shape.  `rank(mask)` must be known statically.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A potentially ragged tensor that is formed by retaining the elements in
    `data` where the corresponding value in `mask` is `True`.

    * `rank(output) = rank(data)`.
    * `output.ragged_rank = max(data.ragged_rank, rank(mask) - 1)`.

  Raises:
    ValueError: if `rank(mask)` is not known statically; if `mask` is a
      scalar; or if `mask.shape` is not a prefix of `data.shape`.

  #### Examples:

  >>> # Aliases for True & False so data and mask line up.
  >>> T, F = (True, False)

  >>> tf.ragged.boolean_mask(  # Mask a 2D Tensor.
  ...     data=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
  ...     mask=[[T, F, T], [F, F, F], [T, F, F]]).to_list()
  [[1, 3], [], [7]]

  >>> tf.ragged.boolean_mask(  # Mask a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([[F, F, T], [F], [T, T]])).to_list()
  [[3], [], [5, 6]]

  >>> tf.ragged.boolean_mask(  # Mask rows of a 2D RaggedTensor.
  ...     tf.ragged.constant([[1, 2, 3], [4], [5, 6]]),
  ...     tf.ragged.constant([True, False, True])).to_list()
  [[1, 2, 3], [5, 6]]
  """
  with ops.name_scope(name, 'RaggedMask', [data, mask]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    mask = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        mask, dtypes.bool, name='mask')
    # `row_splits_dtype` is reused below as the integer dtype for every row
    # length that is computed from `mask`.
    row_splits_dtype, (data, mask) = ragged_tensor.match_row_splits_dtypes(
        data, mask, return_dtype=True)

    # Get static rank of mask.
    if mask.shape.ndims is None:
      raise ValueError('mask.shape.ndims must be known statically.')
    elif mask.shape.ndims == 0:
      raise ValueError('mask cannot be scalar.')

    # If mask is ragged, then recurse with a non-ragged mask.
    if ragged_tensor.is_ragged(mask):
      if not ragged_tensor.is_ragged(data):
        # Give `data` the same number of ragged dimensions as `mask`, so
        # that their row splits can be compared and stripped in lock-step.
        data = ragged_tensor.RaggedTensor.from_tensor(
            data,
            ragged_rank=mask.ragged_rank,
            row_splits_dtype=mask.row_splits.dtype)
      # Check that mask.nested_row_splits is a prefix of
      # data.nested_row_splits.
      splits_list = [
          mask.nested_row_splits, data.nested_row_splits[:mask.ragged_rank]
      ]
      with ops.control_dependencies(
          ragged_util.assert_splits_match(splits_list)):
        # Strip off ragged `splits` until `mask` is non-ragged.  Keep the splits
        # that we strip off in `splits`, so we can add them back on after
        # we recursively mask the non-ragged data.
        splits = []
        while ragged_tensor.is_ragged(mask):
          if mask.shape.ndims > 2:
            # This is not the innermost mask dimension, so masking does not
            # change the row lengths here; reuse the splits unchanged.
            splits.append(mask.row_splits)
          else:
            # Count the number of True mask values in each row to find the
            # lengths of the filtered rows; then convert to splits.
            int_mask = ragged_functional_ops.map_flat_values(
                math_ops.cast, mask, dtype=row_splits_dtype)
            masked_row_lengths = ragged_math_ops.reduce_sum(int_mask, axis=1)
            splits.append(ragged_util.lengths_to_splits(masked_row_lengths))
          mask = mask.values
          data = data.values

        # Recursively apply the nested non-ragged mask to the nested data.
        masked_values = boolean_mask(data, mask)

        # Add the ragged `splits` back to the result.
        masked_values = ragged_tensor.RaggedTensor.from_nested_row_splits(
            masked_values, splits, validate=False)

        return masked_values

    # If mask is non-ragged and has rank 1, and data is ragged, then build a
    # ragged tensor with the indicated rows.
    elif ragged_tensor.is_ragged(data) and mask.shape.ndims == 1:
      # Get the masked splits: first get the length of each row, then filter
      # out the rows that we are deleting, and convert that filtered set of
      # masks back to a splits tensor.
      lengths = data.row_lengths()
      masked_lengths = array_ops.boolean_mask(lengths, mask)
      masked_splits = ragged_util.lengths_to_splits(masked_lengths)

      # Get the masked values: first get row ids corresponding to each
      # value, then use tf.gather to build a boolean mask that's false for
      # values that come from rows that we are deleting, and use that mask to
      # construct the masked values tensor.
      segment_ids = segment_id_ops.row_splits_to_segment_ids(data.row_splits)
      segment_mask = array_ops.gather(mask, segment_ids)
      # `data.values` may itself be ragged, hence the recursive call rather
      # than a direct call to `array_ops.boolean_mask`.
      masked_values = boolean_mask(data.values, segment_mask)

      return ragged_tensor.RaggedTensor.from_row_splits(
          masked_values, masked_splits, validate=False)

    # If mask is non-ragged and has rank>1, then convert it to be ragged,
    # with a ragged rank matching data.
    if ragged_tensor.is_ragged(data):
      mask = ragged_tensor.RaggedTensor.from_tensor(
          mask,
          ragged_rank=min(data.ragged_rank, mask.shape.ndims - 1),
          row_splits_dtype=data.row_splits.dtype)
      # The recursive call now takes the "ragged mask" branch above.
      return boolean_mask(data, mask)

    # Otherwise, data and mask are both `Tensor`s.
    else:
      # Apply `boolean_mask` to get the masked values.
      masked_values = array_ops.boolean_mask(data, mask)

      if mask.shape.ndims >= 2:
        # Add the innermost ragged dimension.  For each innermost cell, get the
        # number of values it contains.  Then flatten that to get a list of
        # cell lengths, and convert it to splits.  Finally, combine the splits
        # and values to get the innermost ragged tensor.
        masked_lengths = math_ops.count_nonzero(
            mask, axis=-1, dtype=row_splits_dtype)
        flattened_masked_lengths = array_ops.reshape(masked_lengths, [-1])
        masked_values = ragged_tensor.RaggedTensor.from_row_lengths(
            masked_values, flattened_masked_lengths, validate=False)

        # Wrap remaining ragged dimensions.
        if mask.shape.ndims > 2:
          mask_shape = array_ops.shape(mask, out_type=row_splits_dtype)
          # split_size[dim] = prod(mask_shape[:dim + 1]) + 1, i.e. the number
          # of split points needed for the rows that start at dimension `dim`.
          split_size = math_ops.cumprod(mask_shape) + 1
          # Work from the innermost remaining dimension out to dimension 0.
          for dim in range(mask.shape.ndims - 3, -1, -1):
            # Every row at this level has the same length (`elt_size`), so
            # the splits are evenly spaced multiples of that length.
            elt_size = mask_shape[dim + 1]
            masked_splits = math_ops.range(split_size[dim]) * elt_size
            masked_values = ragged_tensor.RaggedTensor.from_row_splits(
                masked_values, masked_splits, validate=False)

      return masked_values
206
207
208#===============================================================================
209# Tiling
210#===============================================================================
def tile(input, multiples, name=None):  # pylint: disable=redefined-builtin
  """Tiles a potentially ragged tensor along each of its dimensions.

  Dimension `i` of `input` is repeated `multiples[i]` times.  As a result, for
  every dimension `axis` of `input`, each output element along that dimension
  is `multiples[axis]` times as long as the corresponding input element.

  Args:
    input: A `RaggedTensor`.  A non-ragged input is also accepted, in which
      case the call is forwarded to the dense tile op.
    multiples: A 1-D integer `Tensor`.  Length must be the same as the number of
      dimensions in `input`.
    name: A name for the operation (optional).

  Returns:
    A `RaggedTensor` with the same type, rank, and ragged_rank as `input`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> tf.tile(rt, [3, 2]).to_list()
  [[1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3], [1, 2, 1, 2], [3, 3]]
  """
  with ops.name_scope(name, 'RaggedTile', [input, multiples]):
    rt_input = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        input, name='input')

    # Nothing ragged to handle: defer to the dense implementation.
    if not ragged_tensor.is_ragged(rt_input):
      return array_ops.tile(rt_input, multiples, name)

    # Use the row_splits dtype so that `multiples` can be combined with the
    # row-splits tensors directly.
    repeats = ragged_util.convert_to_int_tensor(
        multiples, name='multiples', dtype=rt_input.row_splits.dtype)
    repeats.shape.assert_has_rank(1)

    # When the value of `multiples` is statically known (not None), the
    # helpers use it to skip dimensions whose multiple is 1.
    static_repeats = tensor_util.constant_value(repeats)

    tiled_values = _tile_ragged_values(rt_input, repeats, static_repeats)
    tiled_splits = _tile_ragged_splits(rt_input, repeats, static_repeats)
    return ragged_tensor.RaggedTensor.from_nested_row_splits(
        tiled_values, tiled_splits, validate=False)
251
252
def _tile_ragged_values(rt_input, multiples, const_multiples=None):
  """Computes the flat_values of `rt_input` tiled by `multiples`.

  The result repeats the entries of `rt_input.flat_values` in the order that
  is needed to build the `RaggedTensor` obtained by tiling `rt_input` as
  specified by `multiples`.

  Args:
    rt_input: The `RaggedTensor` whose values should be repeated.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A `Tensor` with the same type and rank as `rt_input.flat_values`.

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_values(rt, tf.constant([3, 2])).numpy()
  array([1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3, 1, 2, 1, 2, 3, 3], dtype=int32)
  """
  num_ragged_dims = rt_input.ragged_rank
  row_splits_per_dim = rt_input.nested_row_splits

  # Start with one index per entry of `rt_input.flat_values`; the loop below
  # duplicates ranges of these indices instead of the values themselves.
  value_ids = math_ops.range(row_splits_per_dim[-1][-1])

  # Visit the ragged dimensions from the innermost one outwards.
  inner_splits = None
  for dim in reversed(range(1, num_ragged_dims + 1)):
    dim_splits = row_splits_per_dim[dim - 1]

    # On every pass but the first, translate this dimension's splits so that
    # they index into `value_ids` rather than into the next dimension's
    # values.
    if inner_splits is not None:
      dim_splits = array_ops.gather(inner_splits * multiples[dim + 1],
                                    dim_splits)

    # Repeat each row of this dimension `multiples[dim]` times, unless the
    # multiple is statically known to be 1.
    if const_multiples is None or const_multiples[dim] != 1:
      value_ids = ragged_util.repeat_ranges(value_ids, dim_splits,
                                            multiples[dim])

    inner_splits = dim_splits

  # Materialize the values selected by the expanded index list.
  gathered_values = array_ops.gather(rt_input.flat_values, value_ids)

  # The uniform dimensions (`axis=0` and `axis=range(ragged_rank, rank)`) are
  # tiled with the regular dense op.
  uniform_multiples = array_ops.concat(
      [multiples[:1], multiples[num_ragged_dims + 1:]], axis=0)
  return array_ops.tile(gathered_values, uniform_multiples)
310
311
def _tile_ragged_splits(rt_input, multiples, const_multiples=None):
  """Computes the nested row splits of `rt_input` tiled by `multiples`.

  The returned split tensors can be combined with the tiled flat values to
  construct the `RaggedTensor` that tiles `rt_input` as specified by
  `multiples`.

  Args:
    rt_input: The `RaggedTensor` that is being tiled.
    multiples: A 1-D integer `tensor`, indicating how many times each dimension
      should be repeated.
    const_multiples: Optional constant value for multiples.  Used to skip tiling
      dimensions where `multiples=1`.

  Returns:
    A list of 1-D integer `Tensor`s (one for each ragged dimension in
    `rt_input`).

  #### Example:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> _tile_ragged_splits(rt, [3, 2])
  [<tf.Tensor: shape=(7,), dtype=int64,
  numpy=array([ 0,  4,  6, 10, 12, 16, 18])>]
  """
  num_ragged_dims = rt_input.ragged_rank
  row_splits_per_dim = rt_input.nested_row_splits

  # splits_in[src][dst] holds the split points that divide the rows of ragged
  # dimension `src` within the list of values of dimension `dst`.  So
  # splits_in[i][i] is row_splits_per_dim[i], and splits_in[i][i+1] is
  # gather(row_splits_per_dim[i+1], row_splits_per_dim[i]).
  splits_in = [{dim: row_splits_per_dim[dim]}
               for dim in range(num_ragged_dims)]
  for src in range(num_ragged_dims):
    for dst in range(src + 1, num_ragged_dims - 1):
      splits_in[src][dst] = array_ops.gather(row_splits_per_dim[dst],
                                             splits_in[src][dst - 1])

  # Build one output splits tensor per ragged dimension.
  tiled_splits = []
  for dim in range(num_ragged_dims):
    # Row lengths of the input in this dimension, scaled by the multiple of
    # the dimension that these rows contain (dim + 1).
    dim_splits = row_splits_per_dim[dim]
    row_lengths = (dim_splits[1:] - dim_splits[:-1]) * multiples[dim + 1]

    # Tile ranges of the row lengths for every enclosing ragged dimension,
    # walking from the closest one (dim - 1) out to dimension 0.
    outer_repeats = 1
    for outer in reversed(range(dim)):
      if const_multiples is None or const_multiples[outer + 1] != 1:
        range_splits = splits_in[outer][dim - 1] * outer_repeats
        row_lengths = ragged_util.repeat_ranges(row_lengths, range_splits,
                                                multiples[outer + 1])
      outer_repeats *= multiples[outer + 1]

    # The outermost dimension is uniform, so a plain tile handles it.
    row_lengths = array_ops.tile(row_lengths, multiples[:1])

    tiled_splits.append(ragged_util.lengths_to_splits(row_lengths))

  return tiled_splits
377
378
379#===============================================================================
380# Reshaping
381#===============================================================================
382
383
def expand_dims(input, axis, name=None):  # pylint: disable=redefined-builtin
  """Inserts a dimension with shape 1 into a potentially ragged tensor's shape.

  Given a potentially ragged tensor `input`, this operation inserts a
  dimension with size 1 at the dimension `axis` of `input`'s shape.

  The following table gives some examples showing how `ragged.expand_dims`
  impacts the shapes of different input tensors.  Ragged dimensions are
  indicated by enclosing them in parentheses.

  input.shape             | axis | result.shape
  ----------------------- | ---- | -----------------------------
  `[D1, D2]`              |  `0` | `[1, D1, D2]`
  `[D1, D2]`              |  `1` | `[D1, 1, D2]`
  `[D1, D2]`              |  `2` | `[D1, D2, 1]`
  `[D1, (D2), (D3), D4]`  |  `0` | `[1, D1, (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `1` | `[D1, 1, (D2), (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `2` | `[D1, (D2), 1, (D3), D4]`
  `[D1, (D2), (D3), D4]`  |  `3` | `[D1, (D2), (D3), 1, D4]`
  `[D1, (D2), (D3), D4]`  |  `4` | `[D1, (D2), (D3), D4, 1]`

  Args:
    input: The potentially ragged tensor that should be expanded with a new
      dimension.
    axis: An integer constant indicating where the new dimension should be
      inserted.
    name: A name for the operation (optional).

  Returns:
    A tensor with the same values as `input`, with an added dimension of
    size 1 at `axis`.

  #### Examples:

  >>> rt = tf.ragged.constant([[1, 2], [3]])
  >>> print(rt.shape)
  (2, None)

  >>> expanded = tf.expand_dims(rt, axis=0)
  >>> print(expanded.shape, expanded)
  (1, 2, None) <tf.RaggedTensor [[[1, 2], [3]]]>

  >>> expanded = tf.expand_dims(rt, axis=1)
  >>> print(expanded.shape, expanded)
  (2, 1, None) <tf.RaggedTensor [[[1, 2]], [[3]]]>

  >>> expanded = tf.expand_dims(rt, axis=2)
  >>> print(expanded.shape, expanded)
  (2, None, 1) <tf.RaggedTensor [[[1], [2]], [[3]]]>
  """
  with ops.name_scope(name, 'RaggedExpandDims', [input]):
    rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(input, name='input')

    # Dense tensors go straight to the dense op.
    if not ragged_tensor.is_ragged(rt):
      return array_ops.expand_dims(rt, axis)

    # `axis` is resolved against the rank of the *result* (input rank + 1),
    # since the new dimension may be appended after the last input dimension.
    if rt.shape.ndims is None:
      result_rank = None
    else:
      result_rank = rt.shape.ndims + 1
    positive_axis = array_ops.get_positive_axis(
        axis, result_rank, ndims_name='rank(input)')

    # New outermost dimension: a single row that holds every row of `rt`.
    if positive_axis == 0:
      return ragged_tensor.RaggedTensor.from_uniform_row_length(
          rt, uniform_row_length=rt.nrows(), nrows=1, validate=False)

    # New second dimension: each row of `rt` becomes its own length-1 row.
    if positive_axis == 1:
      return ragged_tensor.RaggedTensor.from_uniform_row_length(
          rt, uniform_row_length=1, nrows=rt.nrows(), validate=False)

    # Deeper axes: expand the values instead, one level further in.
    inner_values = rt.values
    inner_axis = positive_axis - 1
    if ragged_tensor.is_ragged(inner_values):
      expanded_values = expand_dims(inner_values, inner_axis)
    else:
      expanded_values = array_ops.expand_dims(inner_values, inner_axis)
    return rt.with_values(expanded_values)
454
455
456#===============================================================================
457# RaggedTensor Size
458#===============================================================================
459
460
def size(input, out_type=dtypes.int32, name=None):  # pylint: disable=redefined-builtin
  """Returns the size of a potentially ragged tensor.

  For a ragged tensor, the size is computed on its `flat_values`; for any
  other input it is computed on the input itself.

  #### Example:

  >>> tf.size(tf.ragged.constant([[1, 2], [3]])).numpy()
  3

  Args:
    input: A potentially ragged `Tensor`.
    out_type: The numeric output type for the operation.
    name: A name for the operation (optional).

  Returns:
    A Tensor of type `out_type`.
  """
  # Only the innermost values of a ragged tensor are counted.
  values = input.flat_values if ragged_tensor.is_ragged(input) else input
  return array_ops.size(values, out_type=out_type, name=name)
483
484
485#===============================================================================
486# ragged.rank
487#===============================================================================
def rank(input, name=None):  # pylint: disable=redefined-builtin
  """Returns the rank of a RaggedTensor.

  Returns a 0-D `int32` `Tensor` representing the rank of `input`.  For a
  ragged input this is `input.ragged_rank` plus the rank of
  `input.flat_values`.

  #### Example:

  >>> # shape of tensor 't' is [2, None, None]
  >>> t = tf.ragged.constant([[[1], [2, 2]], [[3, 3, 3], [4, 4, 4, 4]]])
  >>> tf.rank(t).numpy()
  3

  Args:
    input: A `RaggedTensor`
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of type `int32`.
  """
  with ops.name_scope(name, 'RaggedRank', [input]) as scope_name:
    if ragged_tensor.is_ragged(input):
      # Each ragged dimension adds one to the rank of the flat values.
      return input.ragged_rank + array_ops.rank(input.flat_values)

    # Non-ragged input: use the dense op, named after the enclosing scope.
    return array_ops.rank(input, scope_name)
512
513
514#===============================================================================
515# ragged.one_hot
516#===============================================================================
def ragged_one_hot(indices,
                   depth,
                   on_value=None,
                   off_value=None,
                   axis=None,
                   dtype=None,
                   name=None):
  """Applies tf.one_hot along the values of a RaggedTensor.

  The one-hot encoding is computed on `indices.flat_values`, and the ragged
  structure of `indices` is then put back around the result.

  Args:
    indices: A `RaggedTensor` of indices.
    depth: Forwarded to `array_ops.one_hot`.
    on_value: Forwarded to `array_ops.one_hot`.
    off_value: Forwarded to `array_ops.one_hot`.
    axis: The axis to fill.  A non-negative `int` is interpreted relative to
      `indices` and must be greater than `indices.ragged_rank`; it is shifted
      by `indices.ragged_rank` before being forwarded.  Any other value is
      forwarded unchanged.
    dtype: Forwarded to `array_ops.one_hot`.
    name: A name for the operation (optional).

  Returns:
    The result of `indices.with_flat_values` applied to the one-hot encoded
    flat values.

  Raises:
    ValueError: If `axis` is a non-negative `int` that is not greater than
      `indices.ragged_rank`.
  """
  # Translate `axis` into an axis of `indices.flat_values`.
  # Note: the only negative `axis` value supported by array_ops.one_hot is -1.
  if isinstance(axis, int) and axis >= 0:
    num_ragged_dims = indices.ragged_rank
    if axis <= num_ragged_dims:
      raise ValueError('axis (%d) must be greater than indices.ragged_rank '
                       '(%d).' % (axis, num_ragged_dims))
    axis -= num_ragged_dims

  with ops.name_scope(name, 'RaggedOneHot',
                      [indices, depth, on_value, off_value, axis]):
    rt_indices = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        indices, name='indices')
    one_hot_values = array_ops.one_hot(rt_indices.flat_values, depth, on_value,
                                       off_value, axis, dtype, name)
    return rt_indices.with_flat_values(one_hot_values)
540
541
542#===============================================================================
543# ragged.stack_dynamic_partitions
544#===============================================================================
@tf_export('ragged.stack_dynamic_partitions')
@dispatch.add_dispatch_support
def stack_dynamic_partitions(data, partitions, num_partitions, name=None):
  """Stacks dynamic partitions of a Tensor or RaggedTensor.

  Returns a RaggedTensor `output` with `num_partitions` rows, where the row
  `output[i]` is formed by stacking all slices `data[j1...jN]` such that
  `partitions[j1...jN] = i`.  Slices of `data` are stacked in row-major
  order.

  If `num_partitions` is an `int` (not a `Tensor`), then this is equivalent to
  `tf.ragged.stack(tf.dynamic_partition(data, partitions, num_partitions))`.

  #### Example:

  >>> data           = ['a', 'b', 'c', 'd', 'e']
  >>> partitions     = [  3,   0,   2,   2,   3]
  >>> num_partitions = 5
  >>> tf.ragged.stack_dynamic_partitions(data, partitions, num_partitions)
  <tf.RaggedTensor [[b'b'], [], [b'c', b'd'], [b'a', b'e'], []]>

  Args:
    data: A `Tensor` or `RaggedTensor` containing the values to stack.
    partitions: An `int32` or `int64` `Tensor` or `RaggedTensor` specifying the
      partition that each slice of `data` should be added to. `partitions.shape`
      must be a prefix of `data.shape`.  Values must be greater than or equal to
      zero, and less than `num_partitions`. `partitions` is not required to be
      sorted.
    num_partitions: An `int32` or `int64` scalar specifying the number of
      partitions to output.  This determines the number of rows in `output`.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A `RaggedTensor` containing the stacked partitions.  The returned tensor
    has the same dtype as `data`, and its shape is
    `[num_partitions, (D)] + data.shape[partitions.rank:]`, where `(D)` is a
    ragged dimension whose length is the number of data slices stacked for
    each `partition`.

  Raises:
    ValueError: If the rank of `partitions` is not statically known.
  """
  with ops.name_scope(name, 'SegmentStack', [data, partitions, num_partitions]):
    # Convert inputs to tensors.
    data = ragged_tensor.convert_to_tensor_or_ragged_tensor(data, name='data')
    # When `data` is ragged, its row_splits dtype is preferred (and later
    # enforced) for `partitions`; otherwise no dtype preference is given.
    row_splits_dtype = (
        data.row_splits.dtype
        if isinstance(data, ragged_tensor.RaggedTensor) else None)
    partitions = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        partitions, name='partitions', preferred_dtype=row_splits_dtype)
    num_partitions = ops.convert_to_tensor(
        num_partitions, name='num_partitions', preferred_dtype=partitions.dtype)
    if row_splits_dtype is not None:
      partitions = math_ops.cast(partitions, row_splits_dtype)
    # `num_partitions` always ends up with the same dtype as `partitions`.
    num_partitions = math_ops.cast(num_partitions, partitions.dtype)

    # Sanity-checks for shapes.
    partitions_rank = partitions.shape.ndims
    if partitions_rank is None:
      raise ValueError('partitions must have known rank.')
    num_partitions.shape.assert_has_rank(0)
    partitions.shape.assert_is_compatible_with(data.shape[:partitions_rank])

    if partitions_rank == 0:
      # If partitions is a scalar, then just create a RaggedTensor that holds
      # the complete `data` value as the single entry of the specified row.
      return ragged_tensor.RaggedTensor.from_value_rowids(
          values=array_ops.stack([data]),
          value_rowids=array_ops.stack([partitions]),
          nrows=num_partitions,
          validate=False)

    elif partitions_rank == 1:
      # If partitions is a vector (the typical case): we can just use data and
      # partitions as the `values` and `value_rowids` for `from_value_rowids`,
      # as long as we sort them first.  The sort is stable, so slices that
      # share a partition keep their original relative order.
      permutation = sort_ops.argsort(partitions, stable=True)
      value_rowids = array_ops.gather(partitions, permutation)
      values = array_ops.gather(data, permutation)
      # `value_rowids` is sorted, so its last element is the largest
      # partition id; checking it alone bounds every id.  The `[-1:]` slice
      # (rather than `[-1]`) is empty when there are no partitions.
      check = check_ops.assert_less(
          value_rowids[-1:],
          num_partitions,
          message='partitions must be less than num_partitions')
      with ops.control_dependencies([check]):
        return ragged_tensor.RaggedTensor.from_value_rowids(
            values, value_rowids, nrows=num_partitions, validate=False)

    else:
      # Handle higher-dimensional partitions via recursion.  Both inputs are
      # made ragged so that one dimension can be peeled off with `.values`.
      if not isinstance(data, ragged_tensor.RaggedTensor):
        data = ragged_tensor.RaggedTensor.from_tensor(
            data, row_splits_dtype=partitions.dtype, ragged_rank=1)
      if not isinstance(partitions, ragged_tensor.RaggedTensor):
        partitions = ragged_tensor.RaggedTensor.from_tensor(
            partitions,
            row_splits_dtype=partitions.dtype,
            ragged_rank=max(data.ragged_rank, partitions_rank - 1))
      # The outer dimension can only be dropped from both inputs if they
      # split their values into rows identically.
      check = check_ops.assert_equal(
          data.row_splits,
          partitions.row_splits,
          message='data and partitions have incompatible ragged shapes')
      with ops.control_dependencies([check]):
        return stack_dynamic_partitions(data.values, partitions.values,
                                        num_partitions)
646
647
648#===============================================================================
649# Reverse
650#===============================================================================
def reverse(tensor, axis, name=None):
  """Reverses a RaggedTensor along the specified axes.

  #### Example:

  >>> data = tf.ragged.constant([
  ...   [[1, 2], [3, 4]], [[5, 6]], [[7, 8], [9, 10], [11, 12]]])
  >>> tf.reverse(data, axis=[0, 2])
  <tf.RaggedTensor [[[8, 7], [10, 9], [12, 11]], [[6, 5]], [[2, 1], [4, 3]]]>

  Args:
    tensor: A 'RaggedTensor' to reverse.
    axis: A list or tuple of 'int' or a constant 1D 'tf.Tensor'. The indices of
      the axes to reverse.
    name: A name prefix for the returned tensor (optional).

  Returns:
    A 'RaggedTensor'.

  Raises:
    TypeError: If `axis` is a `tf.Tensor` whose value cannot be determined
      statically, or if it is neither a tensor nor a list/tuple of `int`.
  """
  # The two literals are concatenated, so the first one must end with a
  # space to keep the words "tensor" and "when" apart in the message.
  type_error_msg = ('`axis` must be a list of int or a constant tensor '
                    'when reversing axes in a ragged tensor')

  with ops.name_scope(name, 'Reverse', [tensor, axis]):
    if isinstance(axis, ops.Tensor):
      # The axes are turned into Python slices below, so their values must
      # be available at graph-construction time.
      axis = tensor_util.constant_value(axis)
      if axis is None:
        raise TypeError(type_error_msg)
    elif not (isinstance(axis, (list, tuple)) and
              all(isinstance(dim, int) for dim in axis)):
      raise TypeError(type_error_msg)

    tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(
        tensor, name='tensor')

    # Allow usage of negative values to specify innermost axes.
    axis = [
        array_ops.get_positive_axis(dim, tensor.shape.rank, 'axis[%d]' % i,
                                    'rank(tensor)')
        for i, dim in enumerate(axis)
    ]

    # We only need to slice up to the max axis. If the axis list
    # is empty, it should be 0.
    slices = [slice(None)] * (max(axis) + 1 if axis else 0)

    # A step of -1 reverses the corresponding dimension; every other
    # dimension keeps the full, forward slice.
    for dim in axis:
      slices[dim] = slice(None, None, -1)

    return tensor[tuple(slices)]
700
701
702#===============================================================================
703# Cross
704#===============================================================================
705
706
@tf_export('ragged.cross')
@dispatch.add_dispatch_support
def cross(inputs, name=None):
  """Generates feature cross from a list of tensors.

  Every input tensor must have `rank=2`, and all inputs must share the same
  number of rows.  The result is a `RaggedTensor` with that same number of
  rows, where `result[row]` lists every combination of values that can be
  formed by picking one value from the corresponding row of each input
  (`inputs[i][row]`).  The values of a combination are joined as strings,
  separated by '_X_'.  E.g.:

  >>> tf.ragged.cross([tf.ragged.constant([['a'], ['b', 'c']]),
  ...                  tf.ragged.constant([['d'], ['e']]),
  ...                  tf.ragged.constant([['f'], ['g']])])
  <tf.RaggedTensor [[b'a_X_d_X_f'], [b'b_X_e_X_g', b'c_X_e_X_g']]>

  Args:
    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
    name: Optional name for the op.

  Returns:
    A 2D `RaggedTensor` of type `string`.
  """
  # Shares its implementation with `cross_hashed`; hashing is disabled here.
  crossed = _cross_internal(inputs=inputs, hashed_output=False, name=name)
  return crossed
732
733
@tf_export('ragged.cross_hashed')
@dispatch.add_dispatch_support
def cross_hashed(inputs, num_buckets=0, hash_key=None, name=None):
  """Builds the hashed feature cross of a list of tensors.

  Every input must have `rank=2`, and all inputs must share the same number
  of rows.  The output is a `RaggedTensor` with that same number of rows.
  For each row, `result[row]` lists every combination that can be formed by
  picking exactly one value from the matching row of each input
  (`inputs[i][row]`).  Each combination is reduced to a single integer by
  hashing together the fingerprints of the picked values.  For example:

  >>> tf.ragged.cross_hashed([tf.ragged.constant([['a'], ['b', 'c']]),
  ...                         tf.ragged.constant([['d'], ['e']]),
  ...                         tf.ragged.constant([['f'], ['g']])],
  ...                        num_buckets=100)
  <tf.RaggedTensor [[78], [66, 74]]>

  Args:
    inputs: A list of `RaggedTensor` or `Tensor` or `SparseTensor`.
    num_buckets: A non-negative `int` that is used to bucket the hashed
      values. If `num_buckets != 0`, then
      `output = hashed_value % num_buckets`.
    hash_key: Integer hash_key that will be used by the `FingerprintCat64`
      function. If not given, a default key is used.
    name: Optional name for the op.

  Returns:
    A 2D `RaggedTensor` of type `int64`.
  """
  # Hashed crosses share the implementation used by `cross`; the hashing
  # options are simply forwarded.
  hashed = _cross_internal(
      inputs,
      name=name,
      hashed_output=True,
      hash_key=hash_key,
      num_buckets=num_buckets)
  return hashed
769
770
# Hash key that `_cross_internal` falls back to when called with
# `hash_key=None`.
_DEFAULT_CROSS_HASH_KEY = 0xDECAFCAFFE


def _cross_internal(inputs,
                    hashed_output=False,
                    num_buckets=0,
                    hash_key=None,
                    name=None):
  """Generates feature cross from a list of ragged, sparse and dense tensors.

  Each input is normalized (sparse values become `SparseTensor`s, everything
  else is converted to a `Tensor` or `RaggedTensor`), sorted into a ragged,
  sparse or dense bucket, and the buckets are handed to the `RaggedCross` op
  together with a string recording the original input order.

  Args:
    inputs: A `list` or `tuple` of `RaggedTensor`, `SparseTensor`, or values
      convertible to `Tensor`.  Each input must have an integer or `string`
      dtype; integer inputs are cast to `int64`.  Ragged inputs must have
      `ragged_rank=1`.
    hashed_output: If true, the output values have dtype `int64`; otherwise
      they have dtype `string`.  Also forwarded to the op.
    num_buckets: Forwarded to the op as its `num_buckets` attr.
    hash_key: Integer hash key forwarded to the op.  If `None`,
      `_DEFAULT_CROSS_HASH_KEY` is used.  Values `>= 2**63` are treated as
      uint64 and wrapped into the int64 range before being passed on.
    name: Optional name for the name scope and the op.

  Returns:
    A `RaggedTensor` built from the values and row splits produced by the op.
    Its row splits are `int32` only if there is at least one ragged input and
    every ragged input uses `int32` row splits; otherwise they are `int64`.

  Raises:
    TypeError: If `inputs` is not a `list` or `tuple`.
    ValueError: If an input has a dtype that is neither integer nor `string`,
      or if a ragged input has `ragged_rank != 1`.
  """
  if not isinstance(inputs, (tuple, list)):
    raise TypeError('Inputs must be a list')

  if hash_key is None:
    hash_key = _DEFAULT_CROSS_HASH_KEY

  ragged_inputs = []
  sparse_inputs = []
  dense_inputs = []
  # One character per input ('R', 'S' or 'D') so the op can restore the
  # original interleaving of the three per-kind input lists.
  input_order = []
  with ops.name_scope(name, 'RaggedCross', inputs):
    for i, t in enumerate(inputs):
      if sparse_tensor.is_sparse(t):
        t = sparse_tensor.SparseTensor.from_value(t)
      else:
        t = ragged_tensor.convert_to_tensor_or_ragged_tensor(t)
      if t.dtype.is_integer:
        t = math_ops.cast(t, dtypes.int64)
      elif t.dtype != dtypes.string:
        raise ValueError('Unexpected dtype for inputs[%d]: %s' % (i, t.dtype))
      if isinstance(t, ragged_tensor.RaggedTensor):
        if t.ragged_rank != 1:
          raise ValueError('tf.ragged.cross only supports inputs with rank=2')
        ragged_inputs.append(t)
        input_order.append('R')
      elif isinstance(t, sparse_tensor.SparseTensor):
        sparse_inputs.append(t)
        input_order.append('S')
      else:
        dense_inputs.append(t)
        input_order.append('D')

    out_values_type = dtypes.int64 if hashed_output else dtypes.string
    if ragged_inputs and all(
        t.row_splits.dtype == dtypes.int32 for t in ragged_inputs):
      out_row_splits_type = dtypes.int32
    else:
      out_row_splits_type = dtypes.int64

    # Convert hash_key from uint64 -> int64, since we need to pass it via
    # an int64 attr.  The comparison must be inclusive: 2**63 itself is one
    # past the largest int64 value (2**63 - 1) and therefore has to wrap too.
    if hash_key >= 2**63:
      hash_key -= 2**64

    values_out, splits_out = gen_ragged_array_ops.ragged_cross(
        ragged_values=[rt.values for rt in ragged_inputs],
        ragged_row_splits=[rt.row_splits for rt in ragged_inputs],
        sparse_indices=[st.indices for st in sparse_inputs],
        sparse_values=[st.values for st in sparse_inputs],
        sparse_shape=[st.dense_shape for st in sparse_inputs],
        dense_inputs=dense_inputs,
        input_order=''.join(input_order),
        hashed_output=hashed_output,
        num_buckets=num_buckets,
        hash_key=hash_key,
        out_values_type=out_values_type.as_datatype_enum,
        out_row_splits_type=out_row_splits_type.as_datatype_enum,
        name=name)

    # The op output is used as-is: validation of the row splits is skipped.
    return ragged_tensor.RaggedTensor.from_row_splits(
        values_out, splits_out, validate=False)
841