• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for storing ragged tensors and their values."""
16
17import functools
18import operator
19
20import typing
21import numpy as np
22
23from tensorflow.python import tf2
24from tensorflow.python.client import session
25from tensorflow.python.framework import composite_tensor
26from tensorflow.python.framework import composite_tensor_gradient
27from tensorflow.python.framework import constant_op
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.framework import sparse_tensor
31from tensorflow.python.framework import tensor_shape
32from tensorflow.python.framework import tensor_spec
33from tensorflow.python.framework import tensor_util
34from tensorflow.python.framework import type_spec
35from tensorflow.python.ops import array_ops
36from tensorflow.python.ops import check_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import gen_ragged_conversion_ops
39from tensorflow.python.ops import math_ops
40from tensorflow.python.ops.ragged import ragged_config
41from tensorflow.python.ops.ragged import ragged_tensor_value
42from tensorflow.python.ops.ragged import ragged_util
43from tensorflow.python.ops.ragged.row_partition import RowPartition
44from tensorflow.python.types import core as core_types
45from tensorflow.python.types import internal as internal_types
46from tensorflow.python.util import dispatch
47from tensorflow.python.util.tf_export import tf_export
48from tensorflow.tools.docs import doc_controls
49
# pylint: disable=protected-access
# Module-level alias for RowPartition's private conversion helper, so the
# factory methods below can normalize user-supplied row-partitioning tensors
# without repeating the protected-attribute access at every call site.
_convert_row_partition = RowPartition._convert_row_partition
# pylint: enable=protected-access
53
54#===============================================================================
55# RaggedTensor
56#===============================================================================
57
58
59@tf_export("RaggedTensor")
60class RaggedTensor(composite_tensor.CompositeTensor,
61                   internal_types.NativeObject):
62  """Represents a ragged tensor.
63
64  A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are
65  dimensions whose slices may have different lengths.  For example, the inner
66  (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged,
67  since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths.
68  Dimensions whose slices all have the same length are called *uniform
69  dimensions*.  The outermost dimension of a `RaggedTensor` is always uniform,
70  since it consists of a single slice (and so there is no possibility for
71  differing slice lengths).
72
73  The total number of dimensions in a `RaggedTensor` is called its *rank*,
74  and the number of ragged dimensions in a `RaggedTensor` is called its
75  *ragged-rank*.  A `RaggedTensor`'s ragged-rank is fixed at graph creation
76  time: it can't depend on the runtime values of `Tensor`s, and can't vary
77  dynamically for different session runs.
78
79  Note that the `__init__` constructor is private. Please use one of the
80  following methods to construct a `RaggedTensor`:
81
82  * `tf.RaggedTensor.from_row_lengths`
83  * `tf.RaggedTensor.from_value_rowids`
84  * `tf.RaggedTensor.from_row_splits`
85  * `tf.RaggedTensor.from_row_starts`
86  * `tf.RaggedTensor.from_row_limits`
87  * `tf.RaggedTensor.from_nested_row_splits`
88  * `tf.RaggedTensor.from_nested_row_lengths`
89  * `tf.RaggedTensor.from_nested_value_rowids`
90
91  ### Potentially Ragged Tensors
92
93  Many ops support both `Tensor`s and `RaggedTensor`s
94  (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a
95  full listing). The term "potentially ragged tensor" may be used to refer to a
96  tensor that might be either a `Tensor` or a `RaggedTensor`.  The ragged-rank
97  of a `Tensor` is zero.
98
99  ### Documenting RaggedTensor Shapes
100
101  When documenting the shape of a RaggedTensor, ragged dimensions can be
102  indicated by enclosing them in parentheses.  For example, the shape of
103  a 3-D `RaggedTensor` that stores the fixed-size word embedding for each
104  word in a sentence, for each sentence in a batch, could be written as
105  `[num_sentences, (num_words), embedding_size]`.  The parentheses around
106  `(num_words)` indicate that dimension is ragged, and that the length
107  of each element list in that dimension may vary for each item.
108
109  ### Component Tensors
110
111  Internally, a `RaggedTensor` consists of a concatenated list of values that
112  are partitioned into variable-length rows.  In particular, each `RaggedTensor`
113  consists of:
114
115    * A `values` tensor, which concatenates the variable-length rows into a
116      flattened list.  For example, the `values` tensor for
117      `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`.
118
119    * A `row_splits` vector, which indicates how those flattened values are
120      divided into rows.  In particular, the values for row `rt[i]` are stored
121      in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
122
123  Example:
124
125  >>> print(tf.RaggedTensor.from_row_splits(
126  ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
127  ...       row_splits=[0, 4, 4, 7, 8, 8]))
128  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
129
130  ### Alternative Row-Partitioning Schemes
131
132  In addition to `row_splits`, ragged tensors provide support for five other
133  row-partitioning schemes:
134
135    * `row_lengths`: a vector with shape `[nrows]`, which specifies the length
136      of each row.
137
138    * `value_rowids` and `nrows`: `value_rowids` is a vector with shape
139      `[nvals]`, corresponding one-to-one with `values`, which specifies
140      each value's row index.  In particular, the row `rt[row]` consists of the
141      values `rt.values[j]` where `value_rowids[j]==row`.  `nrows` is an
142      integer scalar that specifies the number of rows in the
143      `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.)
144
145    * `row_starts`: a vector with shape `[nrows]`, which specifies the start
146      offset of each row.  Equivalent to `row_splits[:-1]`.
147
148    * `row_limits`: a vector with shape `[nrows]`, which specifies the stop
149      offset of each row.  Equivalent to `row_splits[1:]`.
150
151    * `uniform_row_length`: A scalar tensor, specifying the length of every
152      row.  This row-partitioning scheme may only be used if all rows have
153      the same length.
154
155  Example: The following ragged tensors are equivalent, and all represent the
156  nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`.
157
158  >>> values = [3, 1, 4, 1, 5, 9, 2, 6]
159  >>> RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
160  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
161  >>> RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
162  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
163  >>> RaggedTensor.from_value_rowids(
164  ...     values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
165  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
166  >>> RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
167  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
168  >>> RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
169  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
170  >>> RaggedTensor.from_uniform_row_length(values, uniform_row_length=2)
171  <tf.RaggedTensor [[3, 1], [4, 1], [5, 9], [2, 6]]>
172
173  ### Multiple Ragged Dimensions
174
175  `RaggedTensor`s with multiple ragged dimensions can be defined by using
176  a nested `RaggedTensor` for the `values` tensor.  Each nested `RaggedTensor`
177  adds a single ragged dimension.
178
  >>> inner_rt = RaggedTensor.from_row_splits(
180  ...     values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
181  >>> outer_rt = RaggedTensor.from_row_splits(
182  ...     values=inner_rt, row_splits=[0, 3, 3, 5])
183  >>> print(outer_rt.to_list())
184  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
185  >>> print(outer_rt.ragged_rank)
186  2
187
188  The factory function `RaggedTensor.from_nested_row_splits` may be used to
189  construct a `RaggedTensor` with multiple ragged dimensions directly, by
190  providing a list of `row_splits` tensors:
191
192  >>> RaggedTensor.from_nested_row_splits(
193  ...     flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
194  ...     nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list()
195  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
196
197  ### Uniform Inner Dimensions
198
199  `RaggedTensor`s with uniform inner dimensions can be defined
200  by using a multidimensional `Tensor` for `values`.
201
202  >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32),
203  ...                                   row_splits=[0, 2, 5])
204  >>> print(rt.to_list())
205  [[[1, 1, 1], [1, 1, 1]],
206   [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]
207  >>> print(rt.shape)
208  (2, None, 3)
209
210  ### Uniform Outer Dimensions
211
212  `RaggedTensor`s with uniform outer dimensions can be defined by using
213  one or more `RaggedTensor` with a `uniform_row_length` row-partitioning
214  tensor.  For example, a `RaggedTensor` with shape `[2, 2, None]` can be
215  constructed with this method from a `RaggedTensor` values with shape
216  `[4, None]`:
217
218  >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
219  >>> print(values.shape)
220  (4, None)
221  >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2)
222  >>> print(rt6)
223  <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
224  >>> print(rt6.shape)
225  (2, 2, None)
226
227  Note that `rt6` only contains one ragged dimension (the innermost
228  dimension). In contrast, if `from_row_splits` is used to construct a similar
229  `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
230
231  >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
232  >>> print(rt7.shape)
233  (2, None, None)
234
235  Uniform and ragged outer dimensions may be interleaved, meaning that a
236  tensor with any combination of ragged and uniform dimensions may be created.
237  For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could
238  be constructed as follows:
239
240  ```python
241  t0 = tf.zeros([1000, 2])                           # Shape:         [1000, 2]
242  t1 = RaggedTensor.from_row_lengths(t0, [...])      #           [160, None, 2]
243  t2 = RaggedTensor.from_uniform_row_length(t1, 8)   #         [20, 8, None, 2]
244  t3 = RaggedTensor.from_uniform_row_length(t2, 4)   #       [5, 4, 8, None, 2]
245  t4 = RaggedTensor.from_row_lengths(t3, [...])      # [3, None, 4, 8, None, 2]
246  ```
247
248  """
249
250  #=============================================================================
251  # Constructor (private)
252  #=============================================================================
253  @doc_controls.do_not_generate_docs
254  def __init__(self, values, row_partition, internal=False):
255    """Creates a `RaggedTensor` with a specified partitioning for `values`.
256
257    This constructor is private -- please use one of the following ops to
258    build `RaggedTensor`s:
259
260      * `tf.RaggedTensor.from_row_lengths`
261      * `tf.RaggedTensor.from_value_rowids`
262      * `tf.RaggedTensor.from_row_splits`
263      * `tf.RaggedTensor.from_row_starts`
264      * `tf.RaggedTensor.from_row_limits`
265      * `tf.RaggedTensor.from_nested_row_splits`
266      * `tf.RaggedTensor.from_nested_row_lengths`
267      * `tf.RaggedTensor.from_nested_value_rowids`
268
269    Args:
270      values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
271      row_partition: A `RowPartition` object, representing the arrangement of
272        the lists at the top level.
273      internal: True if the constructor is being called by one of the factory
274        methods.  If false, an exception will be raised.
275
276    Raises:
277      ValueError: If internal = False. Note that this method is intended only
278                 for internal use.
279      TypeError: If values is not a `RaggedTensor` or `Tensor`, or
280                 row_partition is not a `RowPartition`.
281    """
282
283    if not internal:
284      raise ValueError("RaggedTensor constructor is private; please use one "
285                       "of the factory methods instead (e.g., "
286                       "RaggedTensor.from_row_lengths())")
287    _assert_is_supported_ragged_values_type(values)
288    if not isinstance(row_partition, RowPartition):
289      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
290                      f"Received {row_partition}.")
291
292    # Validate shapes.
293    values.shape.with_rank_at_least(1)
294    if isinstance(values, RaggedTensor):
295      # pylint: disable=protected-access
296      assert row_partition.dtype == values._row_partition.dtype
297
298    self._values = values
299    self._row_partition = row_partition
300
301  #=============================================================================
302  # Factory Methods
303  #=============================================================================
304
305  @classmethod
306  def _from_row_partition(cls, values, row_partition, validate=True):
307    """Creates a `RaggedTensor` with a row partition.
308
309    This is used as a way for RaggedTensors to share row partitions.
310
311    The outer dimension of values must be equal to `partition.nvals()`.
312
313    Args:
314      values: A potentially ragged tensor.
315      row_partition: a `RowPartition`: can be shared between tensors.
316      validate: If true, then use assertions to check that the arguments form a
317        valid `RaggedTensor`.
318
319    Returns:
320      A `RaggedTensor`.  `result.rank = values.rank + 1`.
321      `result.ragged_rank = values.ragged_rank + 1`.
322
323    Raises:
324      ValueError: If partition.nvals() != _nrows(values)
325    """
326    if not isinstance(row_partition, RowPartition):
327      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
328                      f"Received {row_partition}.")
329    if not isinstance(validate, bool):
330      raise TypeError(f"Argument `validate` must have type bool. "
331                      f"Received {validate}.")
332    values, row_partition = cls._convert_values_and_partition(
333        values, row_partition, "partition")
334    if row_partition._has_precomputed_value_rowids():  # pylint: disable=protected-access
335      value_rowids_shape = row_partition.value_rowids().shape
336      values.shape[:1].assert_is_compatible_with(value_rowids_shape)
337    if validate:
338      msg = "Arguments to _from_row_partition do not form a valid RaggedTensor"
339      nvals = _nrows(values, row_partition.dtype)
340      checks = [
341          check_ops.assert_equal(
342              math_ops.cast(row_partition.nvals(), row_partition.dtype),
343              nvals,
344              message=msg),
345      ]
346      if not isinstance(values, RaggedTensor):
347        checks.append(check_ops.assert_rank_at_least(values, 1))
348      row_partition = row_partition._with_dependencies(checks)  # pylint: disable=protected-access
349    return cls(values=values, internal=True, row_partition=row_partition)
350
351  @classmethod
352  @dispatch.add_dispatch_support
353  def from_value_rowids(cls,
354                        values,
355                        value_rowids,
356                        nrows=None,
357                        name=None,
358                        validate=True):
359    """Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
360
361    The returned `RaggedTensor` corresponds with the python list defined by:
362
363    ```python
364    result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
365              for row in range(nrows)]
366    ```
367
368    Args:
369      values: A potentially ragged tensor with shape `[nvals, ...]`.
370      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
371        one-to-one with `values`, and specifies each value's row index.  Must be
372        nonnegative, and must be sorted in ascending order.
373      nrows: An integer scalar specifying the number of rows.  This should be
374        specified if the `RaggedTensor` may containing empty training rows. Must
375        be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
376        Defaults to `value_rowids[-1] + 1` (or zero if `value_rowids` is empty).
377      name: A name prefix for the RaggedTensor (optional).
378      validate: If true, then use assertions to check that the arguments form
379        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
380          since they must be checked for each tensor value.
381
382    Returns:
383      A `RaggedTensor`.  `result.rank = values.rank + 1`.
384      `result.ragged_rank = values.ragged_rank + 1`.
385
386    Raises:
387      ValueError: If `nrows` is incompatible with `value_rowids`.
388
389    #### Example:
390
391    >>> print(tf.RaggedTensor.from_value_rowids(
392    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
393    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
394    ...     nrows=5))
395    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
396
397    """
398    if not isinstance(validate, bool):
399      raise TypeError(f"Argument `validate` must have type bool. "
400                      f"Received {validate}.")
401
402    with ops.name_scope(name, "RaggedFromValueRowIds",
403                        [values, value_rowids, nrows]):
404      row_partition = RowPartition.from_value_rowids(
405          value_rowids=value_rowids,
406          nrows=nrows,
407          validate=validate,
408          dtype_hint=_get_optional_partition_dtype(values))
409      return cls._from_row_partition(values, row_partition, validate=validate)
410
411  @classmethod
412  @dispatch.add_dispatch_support
413  def from_row_splits(cls, values, row_splits, name=None, validate=True):
414    """Creates a `RaggedTensor` with rows partitioned by `row_splits`.
415
416    The returned `RaggedTensor` corresponds with the python list defined by:
417
418    ```python
419    result = [values[row_splits[i]:row_splits[i + 1]]
420              for i in range(len(row_splits) - 1)]
421    ```
422
423    Args:
424      values: A potentially ragged tensor with shape `[nvals, ...]`.
425      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
426        empty, and must be sorted in ascending order.  `row_splits[0]` must be
427        zero and `row_splits[-1]` must be `nvals`.
428      name: A name prefix for the RaggedTensor (optional).
429      validate: If true, then use assertions to check that the arguments form
430        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
431          since they must be checked for each tensor value.
432
433    Returns:
434      A `RaggedTensor`.  `result.rank = values.rank + 1`.
435      `result.ragged_rank = values.ragged_rank + 1`.
436
437    Raises:
438      ValueError: If `row_splits` is an empty list.
439
440    #### Example:
441
442    >>> print(tf.RaggedTensor.from_row_splits(
443    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
444    ...     row_splits=[0, 4, 4, 7, 8, 8]))
445    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
446
447    """
448    if not isinstance(validate, bool):
449      raise TypeError(f"Argument `validate` must have type bool. "
450                      f"Received {validate}.")
451
452    with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
453      row_partition = RowPartition.from_row_splits(
454          row_splits=row_splits,
455          validate=validate,
456          dtype_hint=_get_optional_partition_dtype(values))
457      return cls._from_row_partition(values, row_partition, validate=validate)
458
459  @classmethod
460  @dispatch.add_dispatch_support
461  def from_row_lengths(cls, values, row_lengths, name=None, validate=True):
462    """Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
463
464    The returned `RaggedTensor` corresponds with the python list defined by:
465
466    ```python
467    result = [[values.pop(0) for i in range(length)]
468              for length in row_lengths]
469    ```
470
471    Args:
472      values: A potentially ragged tensor with shape `[nvals, ...]`.
473      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
474        nonnegative.  `sum(row_lengths)` must be `nvals`.
475      name: A name prefix for the RaggedTensor (optional).
476      validate: If true, then use assertions to check that the arguments form
477        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
478          since they must be checked for each tensor value.
479
480    Returns:
481      A `RaggedTensor`.  `result.rank = values.rank + 1`.
482      `result.ragged_rank = values.ragged_rank + 1`.
483
484    #### Example:
485
486    >>> print(tf.RaggedTensor.from_row_lengths(
487    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
488    ...     row_lengths=[4, 0, 3, 1, 0]))
489    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
490
491    """
492    if not isinstance(validate, bool):
493      raise TypeError(f"Argument `validate` must have type bool. "
494                      f"Received {validate}.")
495
496    with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
497      row_partition = RowPartition.from_row_lengths(
498          row_lengths=row_lengths,
499          validate=validate,
500          dtype_hint=_get_optional_partition_dtype(values))
501      return cls._from_row_partition(values, row_partition, validate=validate)
502
503  @classmethod
504  @dispatch.add_dispatch_support
505  def from_row_starts(cls, values, row_starts, name=None, validate=True):
506    """Creates a `RaggedTensor` with rows partitioned by `row_starts`.
507
508    Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`.
509
510    Args:
511      values: A potentially ragged tensor with shape `[nvals, ...]`.
512      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
513        nonnegative and sorted in ascending order.  If `nrows>0`, then
514        `row_starts[0]` must be zero.
515      name: A name prefix for the RaggedTensor (optional).
516      validate: If true, then use assertions to check that the arguments form
517        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
518          since they must be checked for each tensor value.
519
520    Returns:
521      A `RaggedTensor`.  `result.rank = values.rank + 1`.
522      `result.ragged_rank = values.ragged_rank + 1`.
523
524    #### Example:
525
526    >>> print(tf.RaggedTensor.from_row_starts(
527    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
528    ...     row_starts=[0, 4, 4, 7, 8]))
529    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
530
531    """
532    if not isinstance(validate, bool):
533      raise TypeError(f"Argument `validate` must have type bool. "
534                      f"Received {validate}.")
535    with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
536      values = _convert_to_ragged_tensor_values(values)
537      row_partition = RowPartition.from_row_starts(
538          row_starts=row_starts,
539          nvals=_nrows(values),
540          validate=validate,
541          dtype_hint=_get_optional_partition_dtype(values))
542      return cls._from_row_partition(values, row_partition, validate=validate)
543
544  @classmethod
545  @dispatch.add_dispatch_support
546  def from_row_limits(cls, values, row_limits, name=None, validate=True):
547    """Creates a `RaggedTensor` with rows partitioned by `row_limits`.
548
549    Equivalent to: `from_row_splits(values, concat([0, row_limits]))`.
550
551    Args:
552      values: A potentially ragged tensor with shape `[nvals, ...]`.
553      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
554        ascending order.  If `nrows>0`, then `row_limits[-1]` must be `nvals`.
555      name: A name prefix for the RaggedTensor (optional).
556      validate: If true, then use assertions to check that the arguments form
557        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
558          since they must be checked for each tensor value.
559
560    Returns:
561      A `RaggedTensor`.  `result.rank = values.rank + 1`.
562      `result.ragged_rank = values.ragged_rank + 1`.
563
564    #### Example:
565
566    >>> print(tf.RaggedTensor.from_row_limits(
567    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
568    ...     row_limits=[4, 4, 7, 8, 8]))
569    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
570
571    """
572    if not isinstance(validate, bool):
573      raise TypeError(f"Argument `validate` must have type bool. "
574                      f"Received {validate}.")
575    with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
576      values = _convert_to_ragged_tensor_values(values)
577      row_partition = RowPartition.from_row_limits(
578          row_limits=row_limits,
579          validate=validate,
580          dtype_hint=_get_optional_partition_dtype(values))
581      return cls._from_row_partition(values, row_partition, validate=validate)
582
583  @classmethod
584  @dispatch.add_dispatch_support
585  def from_uniform_row_length(cls,
586                              values,
587                              uniform_row_length,
588                              nrows=None,
589                              validate=True,
590                              name=None):
591    """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`.
592
593    This method can be used to create `RaggedTensor`s with multiple uniform
594    outer dimensions.  For example, a `RaggedTensor` with shape `[2, 2, None]`
595    can be constructed with this method from a `RaggedTensor` values with shape
596    `[4, None]`:
597
598    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
599    >>> print(values.shape)
600    (4, None)
601    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
602    >>> print(rt1)
603    <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
604    >>> print(rt1.shape)
605    (2, 2, None)
606
607    Note that `rt1` only contains one ragged dimension (the innermost
608    dimension). In contrast, if `from_row_splits` is used to construct a similar
609    `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
610
611    >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
612    >>> print(rt2.shape)
613    (2, None, None)
614
615    Args:
616      values: A potentially ragged tensor with shape `[nvals, ...]`.
617      uniform_row_length: A scalar integer tensor.  Must be nonnegative. The
618        size of the outer axis of `values` must be evenly divisible by
619        `uniform_row_length`.
620      nrows: The number of rows in the constructed RaggedTensor.  If not
621        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
622        `uniform_row_length==0`).  `nrows` only needs to be specified if
623        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must be
624        `nvals`.
625      validate: If true, then use assertions to check that the arguments form
626        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
627          since they must be checked for each tensor value.
628      name: A name prefix for the RaggedTensor (optional).
629
630    Returns:
631      A `RaggedTensor` that corresponds with the python list defined by:
632
633      ```python
634      result = [[values.pop(0) for i in range(uniform_row_length)]
635                for _ in range(nrows)]
636      ```
637
638      `result.rank = values.rank + 1`.
639      `result.ragged_rank = values.ragged_rank + 1`.
640    """
641    if not isinstance(validate, bool):
642      raise TypeError(f"Argument `validate` must have type bool. "
643                      f"Received {validate}.")
644    with ops.name_scope(name, "RaggedFromUniformRowLength",
645                        [values, uniform_row_length, nrows]):
646      values = _convert_to_ragged_tensor_values(values)
647      uniform_row_length = _convert_row_partition(
648          uniform_row_length, "UniformRowLength",
649          _get_optional_partition_dtype(values))
650      nvals = _nvals_uniform_row_length(values, uniform_row_length)
651      row_partition = RowPartition.from_uniform_row_length(
652          uniform_row_length=uniform_row_length,
653          nvals=nvals,
654          nrows=nrows,
655          validate=validate,
656          dtype_hint=_get_optional_partition_dtype(values))
657      return cls._from_row_partition(values, row_partition, validate=validate)
658
659  @classmethod
660  @dispatch.add_dispatch_support
661  def from_nested_value_rowids(cls,
662                               flat_values,
663                               nested_value_rowids,
664                               nested_nrows=None,
665                               name=None,
666                               validate=True):
667    """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors.
668
669    Equivalent to:
670
671    ```python
672    result = flat_values
673    for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)):
674      result = from_value_rowids(result, rowids, nrows)
675    ```
676
677    Args:
678      flat_values: A potentially ragged tensor.
679      nested_value_rowids: A list of 1-D integer tensors.  The `i`th tensor is
680        used as the `value_rowids` for the `i`th ragged dimension.
681      nested_nrows: A list of integer scalars.  The `i`th scalar is used as the
682        `nrows` for the `i`th ragged dimension.
683      name: A name prefix for the RaggedTensor (optional).
684      validate: If true, then use assertions to check that the arguments form
685        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
686          since they must be checked for each tensor value.
687
688    Returns:
689      A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty).
690
691    Raises:
692      ValueError: If `len(nested_values_rowids) != len(nested_nrows)`.
693    """
694    if not isinstance(validate, bool):
695      raise TypeError(f"Argument `validate` must have type bool. "
696                      f"Received {validate}.")
697    if isinstance(nested_value_rowids, ops.Tensor):
698      raise TypeError(f"Argument `nested_value_rowids` must be a list of "
699                      f"Tensors. Received {nested_value_rowids}.")
700    if nested_nrows is None:
701      nested_nrows = [None] * len(nested_value_rowids)
702    else:
703      if isinstance(nested_nrows, ops.Tensor):
704        raise TypeError(f"Argument `nested_nrows` must be a list of "
705                        f"Tensors. Received {nested_nrows}.")
706      if len(nested_nrows) != len(nested_value_rowids):
707        raise ValueError(
708            f"Argument `nested_nrows` must have the same length as "
709            f"argument `nested_value_rowids`. len(nested_nrows) = "
710            f"{len(nested_nrows)} vs. len(nested_values_rowids) = "
711            f"{len(nested_value_rowids)}.")
712
713    with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] +
714                        list(nested_value_rowids) + list(nested_nrows)):
715      result = flat_values
716      for value_rowids, nrows in reversed(
717          list(zip(nested_value_rowids, nested_nrows))):
718        result = cls.from_value_rowids(
719            result, value_rowids, nrows, validate=validate)
720      return result
721
722  @classmethod
723  @dispatch.add_dispatch_support
724  def from_nested_row_splits(cls,
725                             flat_values,
726                             nested_row_splits,
727                             name=None,
728                             validate=True):
729    """Creates a `RaggedTensor` from a nested list of `row_splits` tensors.
730
731    Equivalent to:
732
733    ```python
734    result = flat_values
735    for row_splits in reversed(nested_row_splits):
736      result = from_row_splits(result, row_splits)
737    ```
738
739    Args:
740      flat_values: A potentially ragged tensor.
741      nested_row_splits: A list of 1-D integer tensors.  The `i`th tensor is
742        used as the `row_splits` for the `i`th ragged dimension.
743      name: A name prefix for the RaggedTensor (optional).
744      validate: If true, then use assertions to check that the arguments form
745        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
746          since they must be checked for each tensor value.
747
748    Returns:
749      A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty).
750    """
751    if not isinstance(validate, bool):
752      raise TypeError(f"Argument `validate` must have type bool. "
753                      f"Received {validate}.")
754    if isinstance(nested_row_splits, ops.Tensor):
755      raise TypeError(f"Argument `nested_row_splits` must be a list of "
756                      f"Tensors. Received {nested_row_splits}.")
757    with ops.name_scope(name, "RaggedFromNestedRowSplits",
758                        [flat_values] + list(nested_row_splits)):
759      result = flat_values
760      for splits in reversed(nested_row_splits):
761        result = cls.from_row_splits(result, splits, validate=validate)
762      return result
763
764  @classmethod
765  @dispatch.add_dispatch_support
766  def from_nested_row_lengths(cls,
767                              flat_values,
768                              nested_row_lengths,
769                              name=None,
770                              validate=True):
771    """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors.
772
773    Equivalent to:
774
775    ```python
776    result = flat_values
777    for row_lengths in reversed(nested_row_lengths):
778      result = from_row_lengths(result, row_lengths)
779    ```
780
781    Args:
782      flat_values: A potentially ragged tensor.
783      nested_row_lengths: A list of 1-D integer tensors.  The `i`th tensor is
784        used as the `row_lengths` for the `i`th ragged dimension.
785      name: A name prefix for the RaggedTensor (optional).
786      validate: If true, then use assertions to check that the arguments form
787        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
788          since they must be checked for each tensor value.
789
790    Returns:
791      A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
792    """
793    if not isinstance(validate, bool):
794      raise TypeError(f"Argument `validate` must have type bool. "
795                      f"Received {validate}.")
796    if isinstance(nested_row_lengths, ops.Tensor):
797      raise TypeError(f"Argument `nested_row_lengths` must be a list of "
798                      f"Tensors. Received {nested_row_lengths}.")
799    with ops.name_scope(name, "RaggedFromNestedRowlengths",
800                        [flat_values] + list(nested_row_lengths)):
801      result = flat_values
802      for lengths in reversed(nested_row_lengths):
803        result = cls.from_row_lengths(result, lengths, validate=validate)
804      return result
805
806  @classmethod
807  def _from_nested_row_partitions(cls,
808                                  flat_values,
809                                  nested_row_partitions,
810                                  name=None,
811                                  validate=True):
812    """Creates a `RaggedTensor` from a nested list of row partitions.
813
814    Equivalent to:
815
816    ```python
817    result = flat_values
818    for row_partition in reversed(nested_row_partitions):
819      result = _from_row_partition(result, row_partition)
820    ```
821
822    Args:
823      flat_values: A potentially ragged tensor.
824      nested_row_partitions: A list of row partitions.  The `i`th element is
825        used as the row partition for the `i`th ragged dimension.
826      name: A name prefix for the RaggedTensor (optional).
827      validate: If true, then use assertions to check that the arguments form
828        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
829          since they must be checked for each tensor value.
830
831    Returns:
832      A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
833    """
834    if not isinstance(validate, bool):
835      raise TypeError(f"Argument `validate` must have type bool. "
836                      f"Received {validate}.")
837    if isinstance(nested_row_partitions, RowPartition):
838      raise TypeError(f"Argument `nested_row_partitions` must be a list of "
839                      f"RowPartitions. Received {nested_row_partitions}.")
840    if isinstance(nested_row_partitions, ops.Tensor):
841      raise TypeError(f"Argument `nested_row_partitions` must be a list of "
842                      f"RowPartitions. Received {nested_row_partitions}.")
843    with ops.name_scope(name, "RaggedFromNestedRowPartitions",
844                        [flat_values] + list(nested_row_partitions)):
845      result = flat_values
846      for partition in reversed(nested_row_partitions):
847        result = cls._from_row_partition(result, partition, validate=validate)
848      return result
849
850  @classmethod
851  def _convert_values_and_partition(cls, values, row_partition, name):
852    """Converts `values` and `partition` to Tensors.
853
854    If `values` is a `RaggedTensor`, then converts `values` and `partition`
855    to have compatible row-partitioning dtypes.  In particular, if any of the
856    row partitioning tensors are `int64`, then all of the other row
857    partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype()
858    is true) or an error will be raised (if auto_cast_partition_dtype() is
859    false).
860
861    Args:
862      values: The `values` for the `RaggedTensor` being constructed.
863      row_partition: A RowPartition object for the `RaggedTensor` being
864        constructed.
865      name: The name of the RowPartition object.
866
867    Returns:
868      A tuple (values, partition).
869    """
870    if not isinstance(row_partition, RowPartition):
871      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
872                      f"Received {row_partition}.")
873    if isinstance(values, RaggedTensor):
874      # pylint: disable=protected-access
875      if values._row_partition.dtype != row_partition.dtype:
876        if not ragged_config.auto_cast_partition_dtype():
877          # pylint: disable=protected-access
878          # TODO(edloper): get rid of the `name` parameter.
879          raise ValueError(
880              f"Argument `row_partition` of RaggedTensor with name: {name} "
881              f"must have same dtype as Argument `values`. "
882              f"({row_partition.dtype} vs. {values._row_partition.dtype}).")
883        values = values.with_row_splits_dtype(row_partition.dtype)
884    else:
885      values = _convert_to_ragged_tensor_values(values)
886
887    return (values, row_partition)
888
889  #=============================================================================
890  # Accessors
891  #=============================================================================
892
893  @property
894  def dtype(self):
895    """The `DType` of values in this tensor."""
896    return self._values.dtype
897
898  @property
899  def shape(self):
900    """The statically known shape of this ragged tensor.
901
902    Returns:
903      A `TensorShape` containing the statically known shape of this ragged
904      tensor.  Ragged dimensions have a size of `None`.
905
906    Examples:
907
908    >>> tf.ragged.constant([[0], [1, 2]]).shape
909    TensorShape([2, None])
910
911    >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
912    TensorShape([2, None, 2])
913
914    """
915    nrows = self._row_partition.static_nrows
916    ncols = self._row_partition.static_uniform_row_length
917    value_shape = self._values.shape[1:]
918    return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape)
919
920  def get_shape(self):
921    """The statically known shape of this ragged tensor.
922
923    Returns:
924      A `TensorShape` containing the statically known shape of this ragged
925      tensor.  Ragged dimensions have a size of `None`.
926
927    Alias for `shape` property.
928
929    Examples:
930
931    >>> tf.ragged.constant([[0], [1, 2]]).get_shape()
932    TensorShape([2, None])
933
934    >>> tf.ragged.constant(
935    ...    [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape()
936    TensorShape([2, None, 2])
937
938    """
939    return self.shape
940
941  @property
942  def ragged_rank(self):
943    """The number of times the RaggedTensor's flat_values is partitioned.
944
945    Examples:
946
947    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
948    >>> values.ragged_rank
949    1
950
951    >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2)
952    >>> rt.ragged_rank
953    2
954
955    Returns:
956      A Python `int` indicating the number of times the underlying `flat_values`
957      Tensor has been partitioned to add a new dimension.
958      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
959    """
960    values_is_ragged = isinstance(self._values, RaggedTensor)
961    return self._values.ragged_rank + 1 if values_is_ragged else 1
962
963  @property
964  def values(self):
965    """The concatenated rows for this ragged tensor.
966
967    `rt.values` is a potentially ragged tensor formed by flattening the two
968    outermost dimensions of `rt` into a single dimension.
969
970    `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the
971    number of items in the outer two dimensions of `rt`).
972
973    `rt.ragged_rank = self.ragged_rank - 1`
974
975    Returns:
976      A potentially ragged tensor.
977
978    #### Example:
979
980    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
981    >>> print(rt.values)
982    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
983
984    """
985    return self._values
986
987  @property
988  def _nested_row_partitions(self):
989    """Returns the row partitions for this `RaggedTensor`."""
990    partitions = [self._row_partition]
991    rt_values = self.values
992    while isinstance(rt_values, RaggedTensor):
993      # pylint: disable=protected-access
994      partitions.append(rt_values._row_partition)
995      rt_values = rt_values.values
996    return tuple(partitions)
997
998  @property
999  def row_splits(self):
1000    """The row-split indices for this ragged tensor's `values`.
1001
1002    `rt.row_splits` specifies where the values for each row begin and end in
1003    `rt.values`.  In particular, the values for row `rt[i]` are stored in
1004    the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
1005
1006    Returns:
1007      A 1-D integer `Tensor` with shape `[self.nrows+1]`.
1008      The returned tensor is non-empty, and is sorted in ascending order.
1009      `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
1010      `self.values.shape[0]`.
1011
1012    #### Example:
1013
1014    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1015    >>> print(rt.row_splits)  # indices of row splits in rt.values
1016    tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64)
1017
1018    """
1019    return self._row_partition.row_splits()
1020
1021  @property
1022  def uniform_row_length(self):
1023    """The length of each row in this ragged tensor, or None if rows are ragged.
1024
1025    >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
1026    >>> print(rt1.uniform_row_length)  # rows are ragged.
1027    None
1028
1029    >>> rt2 = tf.RaggedTensor.from_uniform_row_length(
1030    ...     values=rt1, uniform_row_length=2)
1031    >>> print(rt2)
1032    <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
1033    >>> print(rt2.uniform_row_length)  # rows are not ragged (all have size 2).
1034    tf.Tensor(2, shape=(), dtype=int64)
1035
1036    A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged)
1037    if it can be determined statically (at graph construction time) that the
1038    rows all have the same length.
1039
1040    Returns:
1041      A scalar integer `Tensor`, specifying the length of every row in this
1042      ragged tensor (for ragged tensors whose rows are uniform); or `None`
1043      (for ragged tensors whose rows are ragged).
1044    """
1045    return self._row_partition.uniform_row_length()
1046
1047  @property
1048  def flat_values(self):
1049    """The innermost `values` tensor for this ragged tensor.
1050
1051    Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is
1052    `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`.
1053
1054    Conceptually, `flat_values` is the tensor formed by flattening the
1055    outermost dimension and all of the ragged dimensions into a single
1056    dimension.
1057
1058    `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]`
1059    (where `nvals` is the number of items in the flattened dimensions).
1060
1061    Returns:
1062      A `Tensor`.
1063
1064    #### Example:
1065
1066    >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
1067    >>> print(rt.flat_values)
1068    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1069
1070    """
1071    rt_values = self.values
1072    while isinstance(rt_values, RaggedTensor):
1073      rt_values = rt_values.values
1074    return rt_values
1075
1076  @property
1077  def nested_row_splits(self):
1078    """A tuple containing the row_splits for all ragged dimensions.
1079
1080    `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for
1081    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
1082    particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where:
1083
1084        * `value_splits = ()` if `rt.values` is a `Tensor`.
1085        * `value_splits = rt.values.nested_row_splits` otherwise.
1086
1087    Returns:
1088      A `tuple` of 1-D integer `Tensor`s.
1089
1090    #### Example:
1091
1092    >>> rt = tf.ragged.constant(
1093    ...     [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1094    >>> for i, splits in enumerate(rt.nested_row_splits):
1095    ...   print('Splits for dimension %d: %s' % (i+1, splits.numpy()))
1096    Splits for dimension 1: [0 3]
1097    Splits for dimension 2: [0 3 3 5]
1098    Splits for dimension 3: [0 4 4 7 8 8]
1099
1100    """
1101    rt_nested_splits = [self.row_splits]
1102    rt_values = self.values
1103    while isinstance(rt_values, RaggedTensor):
1104      rt_nested_splits.append(rt_values.row_splits)
1105      rt_values = rt_values.values
1106    return tuple(rt_nested_splits)
1107
1108  def value_rowids(self, name=None):
1109    """Returns the row indices for the `values` in this ragged tensor.
1110
1111    `rt.value_rowids()` corresponds one-to-one with the outermost dimension of
1112    `rt.values`, and specifies the row containing each value.  In particular,
1113    the row `rt[row]` consists of the values `rt.values[j]` where
1114    `rt.value_rowids()[j] == row`.
1115
1116    Args:
1117      name: A name prefix for the returned tensor (optional).
1118
1119    Returns:
1120      A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
1121      The returned tensor is nonnegative, and is sorted in ascending order.
1122
1123    #### Example:
1124
1125    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1126    >>> print(rt.values)
1127    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1128    >>> print(rt.value_rowids())  # corresponds 1:1 with rt.values
1129    tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64)
1130
1131    """
1132    with ops.name_scope(name, "RaggedValueRowIds", [self]):
1133      return self._row_partition.value_rowids()
1134
1135  def nested_value_rowids(self, name=None):
1136    """Returns a tuple containing the value_rowids for all ragged dimensions.
1137
1138    `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors
1139    for
1140    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
1141    particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids`
1142    where:
1143
1144    * `value_ids = ()` if `rt.values` is a `Tensor`.
1145    * `value_ids = rt.values.nested_value_rowids` otherwise.
1146
1147    Args:
1148      name: A name prefix for the returned tensors (optional).
1149
1150    Returns:
1151      A `tuple` of 1-D integer `Tensor`s.
1152
1153    #### Example:
1154
1155    >>> rt = tf.ragged.constant(
1156    ...     [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1157    >>> for i, ids in enumerate(rt.nested_value_rowids()):
1158    ...   print('row ids for dimension %d: %s' % (i+1, ids.numpy()))
1159    row ids for dimension 1: [0 0 0]
1160    row ids for dimension 2: [0 0 0 2 2]
1161    row ids for dimension 3: [0 0 0 0 2 2 2 3]
1162
1163    """
1164    with ops.name_scope(name, "RaggedNestedValueRowIds", [self]):
1165      rt_nested_ids = [self.value_rowids()]
1166      rt_values = self.values
1167      while isinstance(rt_values, RaggedTensor):
1168        rt_nested_ids.append(rt_values.value_rowids())
1169        rt_values = rt_values.values
1170      return tuple(rt_nested_ids)
1171
1172  def nrows(self, out_type=None, name=None):
1173    """Returns the number of rows in this ragged tensor.
1174
1175    I.e., the size of the outermost dimension of the tensor.
1176
1177    Args:
1178      out_type: `dtype` for the returned tensor.  Defaults to
1179        `self.row_splits.dtype`.
1180      name: A name prefix for the returned tensor (optional).
1181
1182    Returns:
1183      A scalar `Tensor` with dtype `out_type`.
1184
1185    #### Example:
1186
1187    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1188    >>> print(rt.nrows())  # rt has 5 rows.
1189    tf.Tensor(5, shape=(), dtype=int64)
1190
1191    """
1192    with ops.name_scope(name, "RaggedNRows", [self]):
1193      if out_type is None:
1194        return self._row_partition.nrows()
1195      else:
1196        return math_ops.cast(self._row_partition.nrows(), dtype=out_type)
1197
1198  def row_starts(self, name=None):
1199    """Returns the start indices for rows in this ragged tensor.
1200
1201    These indices specify where the values for each row begin in
1202    `self.values`.  `rt.row_starts()` is equal to `rt.row_splits[:-1]`.
1203
1204    Args:
1205      name: A name prefix for the returned tensor (optional).
1206
1207    Returns:
1208      A 1-D integer Tensor with shape `[nrows]`.
1209      The returned tensor is nonnegative, and is sorted in ascending order.
1210
1211    #### Example:
1212
1213    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1214    >>> print(rt.values)
1215    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1216    >>> print(rt.row_starts())  # indices of row starts in rt.values
1217    tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)
1218
1219    """
1220    with ops.name_scope(name, "RaggedRowStarts", [self]):
1221      return self._row_partition.row_starts()
1222
1223  def row_limits(self, name=None):
1224    """Returns the limit indices for rows in this ragged tensor.
1225
1226    These indices specify where the values for each row end in
1227    `self.values`.  `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`.
1228
1229    Args:
1230      name: A name prefix for the returned tensor (optional).
1231
1232    Returns:
1233      A 1-D integer Tensor with shape `[nrows]`.
1234      The returned tensor is nonnegative, and is sorted in ascending order.
1235
1236    #### Example:
1237
1238    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1239    >>> print(rt.values)
1240    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1241    >>> print(rt.row_limits())  # indices of row limits in rt.values
1242    tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64)
1243
1244    """
1245    with ops.name_scope(name, "RaggedRowLimits", [self]):
1246      return self._row_partition.row_limits()
1247
1248  def row_lengths(self, axis=1, name=None):
1249    """Returns the lengths of the rows in this ragged tensor.
1250
1251    `rt.row_lengths()[i]` indicates the number of values in the
1252    `i`th row of `rt`.
1253
1254    Args:
1255      axis: An integer constant indicating the axis whose row lengths should be
1256        returned.
1257      name: A name prefix for the returned tensor (optional).
1258
1259    Returns:
1260      A potentially ragged integer Tensor with shape `self.shape[:axis]`.
1261
1262    Raises:
1263      ValueError: If `axis` is out of bounds.
1264
1265    #### Example:
1266
1267    >>> rt = tf.ragged.constant(
1268    ...     [[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []])
1269    >>> print(rt.row_lengths())  # lengths of rows in rt
1270    tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64)
1271    >>> print(rt.row_lengths(axis=2))  # lengths of axis=2 rows.
1272    <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]>
1273
1274    """
1275    if axis == 0:
1276      return self._row_partition.nrows()
1277
1278    if axis == 1:
1279      return self._row_partition.row_lengths()
1280
1281    with ops.name_scope(name, "RaggedRowLengths", [self]):
1282      axis = array_ops.get_positive_axis(
1283          axis, self.shape.rank, ndims_name="rank(self)")
1284      if axis == 0:
1285        return self.nrows()
1286      elif axis == 1:
1287        splits = self.row_splits
1288        return splits[1:] - splits[:-1]
1289      elif isinstance(self.values, RaggedTensor):
1290        return self.with_values(self.values.row_lengths(axis - 1))
1291      else:
1292        shape = array_ops.shape(self.values, out_type=self._row_partition.dtype)
1293        return self.with_values(
1294            array_ops.ones(shape[:axis - 1], self._row_partition.dtype) *
1295            shape[axis - 1])
1296
1297  def nested_row_lengths(self, name=None):
1298    """Returns a tuple containing the row_lengths for all ragged dimensions.
1299
1300    `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors
1301    for all ragged dimensions in `rt`, ordered from outermost to innermost.
1302
1303    Args:
1304      name: A name prefix for the returned tensors (optional).
1305
1306    Returns:
1307      A `tuple` of 1-D integer `Tensors`.  The length of the tuple is equal to
1308      `self.ragged_rank`.
1309    """
1310    with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
1311      rt_nested_row_lengths = []
1312      rt = self
1313      while isinstance(rt, RaggedTensor):
1314        rt_nested_row_lengths.append(rt.row_lengths())
1315        rt = rt.values
1316      return tuple(rt_nested_row_lengths)
1317
  def bounding_shape(self, axis=None, name=None, out_type=None):
    """Returns the tight bounding box shape for this `RaggedTensor`.

    The bounding shape uses the *maximum* row length for each ragged
    dimension, so every element of the tensor fits inside it.

    Args:
      axis: An integer scalar or vector indicating which axes to return the
        bounding box for.  If not specified, then the full bounding box is
        returned.
      name: A name prefix for the returned tensor (optional).
      out_type: `dtype` for the returned tensor.  Defaults to
        `self.row_splits.dtype`.

    Returns:
      An integer `Tensor` (`dtype=self.row_splits.dtype`).  If `axis` is not
      specified, then `output` is a vector with
      `output.shape=[self.shape.ndims]`.  If `axis` is a scalar, then the
      `output` is a scalar.  If `axis` is a vector, then `output` is a vector,
      where `output[i]` is the bounding size for dimension `axis[i]`.

    #### Example:

    >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]])
    >>> rt.bounding_shape().numpy()
    array([5, 4])

    """
    if out_type is None:
      out_type = self._row_partition.dtype
    else:
      out_type = dtypes.as_dtype(out_type)
    with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
      nested_splits = self.nested_row_splits
      rt_flat_values = self.flat_values

      # Optimized special cases for when axis=0 or axis=1:
      if isinstance(axis, int):
        if axis == 0:
          # Number of rows == len(outermost row_splits) - 1.
          return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1
        elif axis == 1:
          # Longest row; the outer maximum() guards the zero-row case.
          result = math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)
          if out_type != self._row_partition.dtype:
            result = math_ops.cast(result, out_type)
          return result

      # General case: compute a bound for every dimension.
      splits_shape = array_ops.shape(self.row_splits, out_type=out_type)
      flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type)

      # One entry for the outermost (row-count) dimension, then the maximum
      # row length for each ragged dimension.  maximum(..., 0) guards
      # dimensions with zero rows.
      ragged_dimensions = [splits_shape[0] - 1] + [
          math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
          for splits in nested_splits
      ]
      # Remaining (uniform) dimensions come from the flat values' shape.
      inner_dimensions = flat_values_shape[1:]

      if out_type != self._row_partition.dtype:
        # The per-splits maxima above are in the row-partition dtype; cast
        # them to the requested output dtype.
        ragged_dimensions = [
            math_ops.cast(d, out_type) for d in ragged_dimensions
        ]
      bbox = array_ops.concat(
          [array_ops.stack(ragged_dimensions), inner_dimensions], axis=0)
      # If `axis` was given (scalar or vector tensor), select just those dims.
      return bbox if axis is None else array_ops.gather(bbox, axis)
1377
1378  #=============================================================================
1379  # Transformation
1380  #=============================================================================
1381
  def with_values(self, new_values):
    """Returns a copy of `self` with `values` replaced by `new_value`.

    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
    `self.cached_value_rowids` if they have values.

    Args:
      new_values: Potentially ragged tensor to use as the `values` for the
        returned `RaggedTensor`.  Must have `rank > 0`, and must have the same
        number of rows as `self.values`.

    Returns:
      A `RaggedTensor`.  `result.rank = 1 + new_values.rank`.
      `result.ragged_rank = 1 + new_values.ragged_rank`

    Raises:
      ValueError: If `new_values` is a `RaggedTensor` whose `row_splits` dtype
        differs from this tensor's, and auto_cast_partition_dtype() is false.
    """
    new_values = _convert_to_ragged_tensor_values(new_values)
    new_values.shape.with_rank_at_least(1)
    # The new values must have the same number of rows as the old values
    # (checked statically here; shapes with unknown dims are compatible).
    self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
    if (isinstance(new_values, RaggedTensor) and
        self._row_partition.dtype != new_values.row_splits.dtype):
      if not ragged_config.auto_cast_partition_dtype():
        raise ValueError("self and new_values have mismatched row_splits "
                         "dtypes; use RaggedTensor.with_row_splits_dtype() to "
                         "convert them to compatible dtypes.")
      # Dtypes differ and are limited to int32/int64, so one side is int64;
      # cast both sides to int64 and retry.
      new_values = new_values.with_row_splits_dtype(dtypes.int64)
      return self.with_row_splits_dtype(dtypes.int64).with_values(new_values)
    # Reuse the existing row partition (keeping any cached tensors it holds).
    return RaggedTensor(
        values=new_values, row_partition=self._row_partition, internal=True)
1410
1411  def with_flat_values(self, new_values):
1412    """Returns a copy of `self` with `flat_values` replaced by `new_value`.
1413
1414    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1415    `self.cached_value_rowids` if they have values.
1416
1417    Args:
1418      new_values: Potentially ragged tensor that should replace
1419        `self.flat_values`.  Must have `rank > 0`, and must have the same number
1420        of rows as `self.flat_values`.
1421
1422    Returns:
1423      A `RaggedTensor`.
1424      `result.rank = self.ragged_rank + new_values.rank`.
1425      `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`.
1426    """
1427    if isinstance(self._values, RaggedTensor):
1428      return self.with_values(self.values.with_flat_values(new_values))
1429    else:
1430      new_values = _convert_to_ragged_tensor_values(new_values)
1431    return self.with_values(new_values)
1432
1433  def with_row_splits_dtype(self, dtype):
1434    """Returns a copy of this RaggedTensor with the given `row_splits` dtype.
1435
1436    For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
1437    nested `RaggedTensor` objects are cast to the given dtype.
1438
1439    Args:
1440      dtype: The dtype for `row_splits`.  One of `tf.int32` or `tf.int64`.
1441
1442    Returns:
1443      A copy of this RaggedTensor, with the `row_splits` cast to the given
1444      type.
1445    """
1446    dtype = dtypes.as_dtype(dtype)
1447    if dtype not in (dtypes.int32, dtypes.int64):
1448      raise ValueError(f"Argument `row_splits` dtype must be int32 or int64. "
1449                       f"Received {dtype}.")
1450    if self._row_partition.dtype == dtype:
1451      return self
1452    current_values = self._values
1453    if isinstance(current_values, RaggedTensor):
1454      return RaggedTensor(
1455          values=current_values.with_row_splits_dtype(dtype),
1456          row_partition=self._row_partition.with_dtype(dtype),
1457          internal=True)
1458    else:
1459      return RaggedTensor(
1460          values=current_values,
1461          row_partition=self._row_partition.with_dtype(dtype),
1462          internal=True)
1463
1464  def merge_dims(self, outer_axis, inner_axis):
1465    """Merges outer_axis...inner_axis into a single dimension.
1466
1467    Returns a copy of this RaggedTensor with the specified range of dimensions
1468    flattened into a single dimension, with elements in row-major order.
1469
1470    #### Examples:
1471
1472    >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]])
1473    >>> print(rt.merge_dims(0, 1))
1474    <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]>
1475    >>> print(rt.merge_dims(1, 2))
1476    <tf.RaggedTensor [[1, 2, 3], [4, 5, 6]]>
1477    >>> print(rt.merge_dims(0, 2))
1478    tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
1479
1480    To mimic the behavior of `np.flatten` (which flattens all dimensions), use
1481    `rt.merge_dims(0, -1).  To mimic the behavior of `tf.layers.Flatten` (which
1482    flattens all dimensions except the outermost batch dimension), use
1483    `rt.merge_dims(1, -1)`.
1484
1485    Args:
1486      outer_axis: `int`: The first dimension in the range of dimensions to
1487        merge. May be negative if `self.shape.rank` is statically known.
1488      inner_axis: `int`: The last dimension in the range of dimensions to merge.
1489        May be negative if `self.shape.rank` is statically known.
1490
1491    Returns:
1492      A copy of this tensor, with the specified dimensions merged into a
1493      single dimension.  The shape of the returned tensor will be
1494      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1495      is the total number of slices in the merged dimensions.
1496    """
1497    outer_axis = array_ops.get_positive_axis(
1498        outer_axis,
1499        self.shape.rank,
1500        axis_name="outer_axis",
1501        ndims_name="rank(self)")
1502    inner_axis = array_ops.get_positive_axis(
1503        inner_axis,
1504        self.shape.rank,
1505        axis_name="inner_axis",
1506        ndims_name="rank(self)")
1507    if not outer_axis <= inner_axis:
1508      raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or "
1509                       f"equal to inner_axis ({inner_axis}).")
1510    return merge_dims(self, outer_axis, inner_axis)
1511
  def _set_shape(self, shape):
    """Updates the static shape of `self` to be `shape`.

    * If a dimension of `shape` has known rank, and is encoded via
      partitioning, then this will update the corresponding partition to
      define `_uniform_row_length` and `nrows`.
    * If a dimension of `shape` has a known rank, and is encoded as one
      of the `flat_values` dimensions, then `flat_values.set_shape()` will
      be used to update its shape.

    Warning: Using this method to assert an incorrect shape for a RaggedTensor
    (i.e., one that's not consistent with its actual shape) can cause
    segmentation faults and very difficult-to-diagnose behavior.  Only use this
    method if you are certain that the shape is correct.

    Args:
      shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.
    """
    # TODO(edloper): Refactor this to not directly access private members
    # of RowPartition.
    # pylint: disable=protected-access

    shape = tensor_shape.as_shape(shape)
    if shape.rank is None:
      return  # Nothing to do.

    shape = shape.as_list()

    # Outermost dimension: a known row count implies that `row_splits` has
    # exactly `nrows + 1` entries.
    if shape[0] is not None:
      self._row_partition._row_splits.set_shape(shape[0] + 1)

    # Partitioned dimensions: record each statically-known size as a uniform
    # row length on the corresponding RowPartition.
    dtype = self._row_partition.dtype
    for i, partition in enumerate(self._nested_row_partitions):
      size = shape[i + 1]
      if size is not None:
        if partition._uniform_row_length is not None:
          old_row_length = tensor_util.constant_value(
              partition._uniform_row_length)
          if old_row_length is not None:
            # The partition already has a statically-known uniform row length;
            # it must agree with the asserted size.
            if size == old_row_length:
              continue  # already have shape info for this axis.
            else:
              raise ValueError(f"Inconsistent size for axis {i + 1}: "
                               f"{old_row_length} vs. {size}.")
        partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
        # A uniform row length is only meaningful alongside a row count, so
        # derive `nrows` from `row_splits` if it isn't already set.
        if partition._nrows is None:
          partition._nrows = array_ops.size(
              partition._row_splits, out_type=dtype) - 1

    # self.flat_values could be a CompositeTensor and doesn't have set_shape.
    if hasattr(self.flat_values, "set_shape"):
      # Inner dimensions (everything past the ragged dimensions); the leading
      # dimension of flat_values stays unknown.
      flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
      self.flat_values.set_shape(flat_shape)
1568
1569  #=============================================================================
1570  # Tensor Type Conversions
1571  #=============================================================================
1572
  @classmethod
  @dispatch.add_dispatch_support
  def from_tensor(cls,
                  tensor,
                  lengths=None,
                  padding=None,
                  ragged_rank=1,
                  name=None,
                  row_splits_dtype=dtypes.int64):
    """Converts a `tf.Tensor` into a `RaggedTensor`.

    The set of absent/default values may be specified using a vector of lengths
    or a padding value (but not both).  If `lengths` is specified, then the
    output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If
    'lengths' is a list of lists or tuple of lists, those lists will be used
    as nested row lengths. If `padding` is specified, then any row *suffix*
    consisting entirely of `padding` will be excluded from the returned
    `RaggedTensor`.  If neither `lengths` nor `padding` is specified, then the
    returned `RaggedTensor` will have no absent/default values.

    Examples:

    >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
    >>> tf.RaggedTensor.from_tensor(dt)
    <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]>
    >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3])
    <tf.RaggedTensor [[5], [], [6, 0, 0]]>

    >>> tf.RaggedTensor.from_tensor(dt, padding=0)
    <tf.RaggedTensor [[5, 7], [0, 3], [6]]>

    >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]],
    ...                   [[0, 0], [3, 0], [0, 0]],
    ...                   [[6, 0], [0, 0], [0, 0]]])
    >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1]))
    <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]>

    Args:
      tensor: The `Tensor` to convert.  Must have rank `ragged_rank + 1` or
        higher.
      lengths: An optional set of row lengths, specified using a 1-D integer
        `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows
        in `tensor`).  If specified, then `output[row]` will contain
        `tensor[row][:lengths[row]]`.  Negative lengths are treated as zero. You
          may optionally pass a list or tuple of lengths to this argument, which
          will be used as nested row lengths to construct a ragged tensor with
          multiple ragged dimensions.
      padding: An optional padding value.  If specified, then any row suffix
        consisting entirely of `padding` will be excluded from the returned
        RaggedTensor.  `padding` is a `Tensor` with the same dtype as `tensor`
        and with `shape=tensor.shape[ragged_rank + 1:]`.
      ragged_rank: Integer specifying the ragged rank for the returned
        `RaggedTensor`.  Must be greater than zero.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the specified `ragged_rank`.  The shape of the
      returned ragged tensor is compatible with the shape of `tensor`.

    Raises:
      ValueError: If both `lengths` and `padding` are specified.
      ValueError: If the rank of `tensor` is 0 or 1.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if lengths is not None and padding is not None:
      raise ValueError("Specify argument `lengths` or `padding`, but not both.")
    if not isinstance(ragged_rank, int):
      raise TypeError(f"Argument `ragged_rank` must be an int. "
                      f"Received {ragged_rank}.")
    if ragged_rank <= 0:
      raise ValueError(f"Argument `ragged_rank` must be greater than 0. "
                       f"Received {ragged_rank}.")

    with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
      tensor = ops.convert_to_tensor(tensor, name="tensor")
      if tensor.shape.rank is not None and tensor.shape.rank < 2:
        raise ValueError(f"The rank of a RaggedTensor must be greater than 1, "
                         f"i.e., a list of scalars won't have ragged "
                         f"dimensions. Received argument `tensor` with rank "
                         f"{tensor.shape.rank}.")
      tensor.shape.with_rank_at_least(ragged_rank + 1)
      input_shape = array_ops.shape(tensor, out_type=row_splits_dtype)
      ncols = input_shape[1]

      # Handle nested row lengths (a list/tuple of per-dimension lengths,
      # detected by the first element not being a scalar number).
      if (lengths is not None and isinstance(lengths, (list, tuple)) and
          len(lengths) and not isinstance(lengths[0], (int, float))):
        if ragged_rank not in (1, len(lengths)):
          # Note: we accept `ragged_rank=1` here because it's the default value;
          # i.e., if the user passes in a tuple of lengths, but doesn't specify
          # ragged_rank, then we should use that tuple to determine ragged_rank.
          # We only want to complain if they pass in an explicit ragged_rank
          # that doesn't match len(lengths).
          raise ValueError(f"If Argument `lengths` is a tuple of row_lengths, "
                           f"argument `ragged_rank` must be "
                           f"len(lengths): {len(lengths)}. Received "
                           f"ragged_rank: {ragged_rank}.")
        # Rather than reconstructing the tensor mask directly, we can
        # recreate it as a boolean RaggedTensor, then densify that and use
        # that as the mask to clear out the unused data in the passed tensor.
        tensor.shape.with_rank_at_least(len(lengths) + 1)
        num_tokens = math_ops.reduce_sum(lengths[-1])
        ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool)
        ragged_mask = cls.from_nested_row_lengths(
            ones_mask, lengths, validate=False)
        dense_ragged_mask = ragged_mask.to_tensor(default_value=False)
        masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask)
        return cls.from_nested_row_lengths(masked_data, lengths, validate=False)

      # Handle ragged_rank>1 via recursion:
      # If the output should have multiple ragged dimensions, then first
      # flatten the tensor to eliminate all but the last ragged dimension,
      # and recursively convert that flattened tensor.  Then add on the splits
      # for the dimensions that we flattened out.
      if ragged_rank > 1:
        if tensor.shape.is_fully_defined():
          input_shape = tensor.shape.as_list()
          # The total number of elements in each dimension.  E.g., if
          # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
          dim_size = np.cumprod(input_shape)
          new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
        else:
          dim_size = math_ops.cumprod(input_shape)
          new_shape = array_ops.concat(
              [[dim_size[ragged_rank - 1]], input_shape[ragged_rank:]], axis=0)
        flattened = array_ops.reshape(tensor, new_shape)
        result = cls.from_tensor(
            flattened, lengths, padding, row_splits_dtype=row_splits_dtype)

        # Add back (innermost-first) the uniform dimensions that were
        # flattened away above.
        for axis in range(ragged_rank - 1, 0, -1):
          dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value
          if dim_len is None:
            dim_len = input_shape[axis]
          else:
            dim_len = constant_op.constant(dim_len, row_splits_dtype)
          result = RaggedTensor.from_uniform_row_length(
              values=result,
              uniform_row_length=dim_len,
              nrows=dim_size[axis - 1],
              validate=False)
        return result

      # If padding was specified, then use it to find row lengths.
      if padding is not None:
        padding = ops.convert_to_tensor(
            padding, name="padding", dtype=tensor.dtype)
        padding.shape.assert_is_compatible_with(tensor.shape[2:])

        # Find places where the padding is equal to the tensor.  (This will
        # broadcast `padding` across the outermost 2 dimensions of `tensor`,
        # so `has_default_value.shape = tensor.shape`.)
        has_default_value = math_ops.equal(padding, tensor)

        # If the padding isn't a scalar, then require that all values in the
        # padding match each item in the tensor.  After this block of code,
        # `has_default.shape = tensor.shape[:2]`.  (Unfortunately, we can't just
        # use reduce_all for both cases, because when you pass an empty `axis`
        # list to reduce_all, it reduces all axes; but we want it to reduce no
        # axes -- i.e., to be a no-op.)
        tensor_rank = array_ops.rank(tensor)
        reduce_axis = math_ops.range(2, tensor_rank)
        has_default = control_flow_ops.cond(
            tensor_rank > 2,
            lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis),
            lambda: has_default_value)
        has_default.set_shape(tensor_shape.TensorShape([None, None]))
        has_default.set_shape(tensor.shape[:2])

        # Use has_default to find the length of each row: for each
        # non-default item in a row, calculate the length that the row needs to
        # have to include that item; and then take the max of those values
        # (across each row).
        has_nondefault = math_ops.logical_not(has_default)
        has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype)
        length_for_nondefault_value = (
            has_nondefault *
            array_ops.expand_dims(math_ops.range(1, ncols + 1), 0))
        lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1)

      if lengths is not None:
        # If we have lengths (either directly supplied, or computed from
        # paddings), then use those to construct splits; and then use masking
        # to get the corresponding values.
        lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
                                                    row_splits_dtype)
        lengths.shape.assert_has_rank(1)
        # Clamp lengths to [0, ncols] before computing splits.
        lengths = math_ops.minimum(lengths, ncols)
        lengths = math_ops.maximum(lengths, 0)
        limits = math_ops.cumsum(lengths)
        splits = array_ops.concat(
            [array_ops.zeros([1], row_splits_dtype), limits], axis=0)
        mask = array_ops.sequence_mask(lengths, maxlen=ncols)
        values = array_ops.boolean_mask(tensor, mask)
        return cls.from_row_splits(values, splits, validate=False)

      # If neither padding nor lengths were specified, then create a splits
      # vector that contains no default values, and reshape the input tensor
      # to form the values for the RaggedTensor.
      values_shape = array_ops.concat(
          [[input_shape[0] * input_shape[1]], input_shape[2:]], axis=0)
      values = array_ops.reshape(tensor, values_shape)
      const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
      const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
      # Prefer statically-known dimension sizes (as constants) over the
      # dynamic `input_shape` values, so shape inference can do more.
      if const_nrows is not None:
        nrows = constant_op.constant(const_nrows, row_splits_dtype)
      else:
        nrows = input_shape[0]
      if const_ncols is not None:
        ncols = constant_op.constant(const_ncols, row_splits_dtype)
      else:
        ncols = input_shape[1]
      return RaggedTensor.from_uniform_row_length(
          values=values, uniform_row_length=ncols, nrows=nrows, validate=False)
1788
  def to_tensor(self, default_value=None, name=None, shape=None):
    """Converts this `RaggedTensor` into a `tf.Tensor`.

    If `shape` is specified, then the result is padded and/or truncated to
    the specified shape.

    Examples:

    >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]])
    >>> print(rt.to_tensor())
    tf.Tensor(
        [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32)
    >>> print(rt.to_tensor(shape=[5, 2]))
    tf.Tensor(
        [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32)

    Args:
      default_value: Value to set for indices not specified in `self`. Defaults
        to zero.  `default_value` must be broadcastable to
        `self.shape[self.ragged_rank + 1:]`.
      name: A name prefix for the returned tensors (optional).
      shape: The shape of the resulting dense tensor.  In particular,
        `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or
        `self.bounding_shape(i)` (otherwise). `shape.rank` must be `None` or
        equal to `self.rank`.

    Returns:
      A `Tensor` with shape `ragged.bounding_shape(self)` and the
      values specified by the non-empty values in `self`.  Empty values are
      assigned `default_value`.
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)
      type_tensor_pairs = _get_row_partition_type_tensor_pairs(self)
      row_partition_types = [x[0] for x in type_tensor_pairs]
      row_partition_tensors = [x[1] for x in type_tensor_pairs]
      if default_value is None:
        default_value = array_ops.zeros((), self.dtype)

      # If `shape` is a list/tuple that mixes Python ints with Tensors, pack
      # it into a single Tensor so it can be fed to the kernel below.
      if (isinstance(shape, (list, tuple)) and
          any(isinstance(v, ops.Tensor) for v in shape) and
          all(isinstance(v, (int, ops.Tensor)) for v in shape)):
        shape = array_ops.stack(shape)

      shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
      tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
          shape=shape_tensor,
          values=self.flat_values,
          default_value=default_value,
          row_partition_types=row_partition_types,
          row_partition_tensors=row_partition_tensors)

      ragged_shape = self.shape

      if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
        # Merge self.shape and shape, favoring the second one as it takes
        # into account potential padding added to the output.
        shape = tensor_shape.as_shape(shape)
        if shape.rank is None:
          output_shape = ragged_shape
        else:
          # At this point we can assume that shape.rank == ragged_shape.rank
          # because otherwise it would have failed earlier.
          output_shape = [
              s1 if s1 is not None else s2
              for (s1, s2) in zip(shape.as_list(), ragged_shape.as_list())
          ]
        tensor.set_shape(output_shape)

      return tensor
1861
  @classmethod
  @dispatch.add_dispatch_support
  def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
    """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.

    Each row of the `output` `RaggedTensor` will contain the explicit values
    from the same row in `st_input`.  `st_input` must be ragged-right.  If it
    is not ragged-right, then an error will be generated.

    Example:

    >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]]
    >>> st = tf.sparse.SparseTensor(indices=indices,
    ...                             values=[1, 2, 3, 4, 5],
    ...                             dense_shape=[4, 3])
    >>> tf.RaggedTensor.from_sparse(st).to_list()
    [[1, 2, 3], [4], [], [5]]

    Currently, only two-dimensional `SparseTensors` are supported.

    Args:
      st_input: The sparse tensor to convert.  Must have rank 2.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the same values as `st_input`.
      `output.ragged_rank = rank(st_input) - 1`.
      `output.shape = [st_input.dense_shape[0], None]`.
    Raises:
      ValueError: If the number of dimensions in `st_input` is not known
        statically, or is not two.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if not sparse_tensor.is_sparse(st_input):
      raise TypeError(f"Argument `st_input` must be of type SparseTensor, but "
                      f"is of type {type(st_input).__name__}.")
    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          st_input, name="st_input")

      # Determine the static rank from whichever component provides it:
      # `dense_shape`'s length, or `indices`'s second dimension.
      if st_input.dense_shape.shape.ndims is None:
        static_rank_from_dense_shape = None
      else:
        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value

      if st_input.indices.shape.ndims is None:
        static_rank_from_indices = None
      else:
        static_rank_from_indices = st_input.indices.shape.dims[1].value

      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
        raise ValueError("rank(st_input) must be 2.")

      with ops.control_dependencies(
          _assert_sparse_indices_are_ragged_right(st_input.indices)):
        # Treat sparse row indices as segment ids to generate a splits tensor
        # that we can pair with the sparse tensor values.  (Ignore sparse column
        # indices.)
        segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
        num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
        return cls.from_value_rowids(
            st_input.values, segment_ids, num_segments, validate=False)
1926
1927  def to_sparse(self, name=None):
1928    """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`.
1929
1930    Example:
1931
1932    >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
1933    >>> print(rt.to_sparse())
1934    SparseTensor(indices=tf.Tensor(
1935                     [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]],
1936                     shape=(6, 2), dtype=int64),
1937                 values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32),
1938                 dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64))
1939
1940    Args:
1941      name: A name prefix for the returned tensors (optional).
1942
1943    Returns:
1944      A SparseTensor with the same values as `self`.
1945    """
1946    with ops.name_scope(name, "RaggedToSparse", [self]):
1947      result = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
1948          self.nested_row_splits, self.flat_values, name=name)
1949      return sparse_tensor.SparseTensor(result.sparse_indices,
1950                                        result.sparse_values,
1951                                        result.sparse_dense_shape)
1952
  @classmethod
  def _from_variant(cls,
                    variant,
                    dtype,
                    output_ragged_rank,
                    input_ragged_rank=None,
                    row_splits_dtype=dtypes.int64,
                    name=None):
    """Converts a `variant` Tensor into a `RaggedTensor`.

    The input `variant` could be a scalar, meaning it encodes a single
    `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it could
    have an arbitrary rank, in which case each element is decoded into a
    `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then
    stacked according to the input shape to output a single `RaggedTensor`
    with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not
    provided, it is inferred dynamically as `output_ragged_rank` -
    `rank(variant)`. If `input_ragged_rank` is provided, the following must be
    true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`.

    Example:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> et = rt._to_variant()
    >>> stacked_et = tf.stack([et, et])
    >>> tf.RaggedTensor._from_variant(  # scalar input.
    ...     et, dtype=tf.int32, output_ragged_rank=1).to_list()
    [[0], [1, 2]]
    >>> tf.RaggedTensor._from_variant(  # batched input.
    ...     stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list()
    [[[0], [1, 2]], [[0], [1, 2]]]

    Args:
      variant: A `variant` Tensor representing an encoded (possibly
        nested-batched) `RaggedTensor`.
      dtype: The dtype of the encoded `RaggedTensor`.
      output_ragged_rank: The expected ragged rank of the output `RaggedTensor`.
      input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is
        optional and inferred dynamically if not provided.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
        of `tf.int32` or `tf.int64`.
      name: A name prefix for the returned tensors (optional).

    Returns:
      A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`.

    Raises:
      ValueError: If the input rank is known, `input_ragged_rank` is provided
          and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does
          not hold.
    """
    variant = ops.convert_to_tensor(
        variant, name="variant", dtype=dtypes.variant)
    # Statically validate the rank relationship when all three quantities
    # are known.
    if (variant.shape.ndims is not None and input_ragged_rank is not None and
        output_ragged_rank != input_ragged_rank + variant.shape.ndims):
      raise ValueError(
          f"Argument `output_ragged_rank` ({output_ragged_rank}) must be equal "
          f"to `input_ragged_rank` + `variant.shape.ndims` "
          f"({input_ragged_rank} + {variant.shape.ndims}).")
    # -1 tells the kernel to infer the input ragged rank dynamically.
    input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank
    with ops.name_scope(
        name, "RaggedFromVariant",
        [variant, dtype, input_ragged_rank, output_ragged_rank]):
      # Clamp output_ragged_rank to at least 0 before passing it to the kernel.
      result = gen_ragged_conversion_ops.ragged_tensor_from_variant(
          variant, input_ragged_rank, max(output_ragged_rank, 0), dtype,
          row_splits_dtype, name)
      return cls.from_nested_row_splits(
          result.output_dense_values,
          result.output_nested_splits,
          validate=False)
2023
2024  def _to_variant(self, batched_input=False, name=None):
2025    """Converts this `RaggedTensor` into a `variant` Tensor.
2026
2027    If `batched_input` is `True`, then the `RaggedTensor` is unbatched along the
2028    zero-th dimension, each component `RaggedTensor` is encoded into a scalar
2029    `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor.
2030    If `batched_input` is `False`, then the `RaggedTensor` is encoded as is and
2031    a scalar `variant` Tensor is returned.
2032
2033    Example:
2034    >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]])
2035    >>> rt._to_variant().shape.as_list()
2036    []
2037    >>> rt._to_variant(batched_input=True).shape.as_list()
2038    [3]
2039
2040    Args:
2041      batched_input: If `True`, the `RaggedTensor` is unbatched and converted to
2042        a `variant` vector. Set to `False` by default.
2043      name: A name prefix for the returned tensors (optional).
2044
2045    Returns:
2046      A `variant` Tensor that encodes this `RaggedTensor`.
2047    """
2048    with ops.name_scope(name, "RaggedToVariant", [self, batched_input]):
2049      return gen_ragged_conversion_ops.ragged_tensor_to_variant(
2050          self.nested_row_splits, self.flat_values, batched_input, name)
2051
2052  #=============================================================================
2053  # String Encoding
2054  #=============================================================================
2055  def __repr__(self):
2056    if self._is_eager():
2057      # The np.array2string in _formatter provides a separator argument, but
2058      # doesn't handle recursive calls correctly. The np.printoptions handles
2059      # recursive calls correctly, but doesn't provide a separator argument.
2060      # Combines them together to print elements separated by comma, while
2061      # avoiding the redundant array prefixes and dtypes. For example,
2062      # the value of tf.ragged.constant([[1, 2], [3, 4]]) will look like
2063      #
2064      # [[1, 2],
2065      #  [3, 4]]
2066      with np.printoptions(formatter={"all": _formatter}):
2067        value_text = _formatter(self.numpy())
2068      return f"<tf.RaggedTensor {value_text}>"
2069    else:
2070      return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self.values,
2071                                                            self.row_splits)
2072
2073  #=============================================================================
2074  # Eager Execution Mode
2075  #=============================================================================
2076
2077  def numpy(self):
2078    """Returns a numpy `array` with the values for this `RaggedTensor`.
2079
2080    Requires that this `RaggedTensor` was constructed in eager execution mode.
2081
2082    Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and
2083    `rank=1`, where each element is a single row.
2084
2085    #### Examples
2086
2087    In the following example, the value returned by `RaggedTensor.numpy()`
2088    contains three numpy `array` objects: one for each row (with `rank=1` and
2089    `dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`):
2090
2091    >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy()
2092    array([array([1, 2, 3]), array([4, 5])], dtype=object)
2093
2094    Uniform dimensions are encoded using multidimensional numpy `array`s.  In
2095    the following example, the value returned by `RaggedTensor.numpy()` contains
2096    a single numpy `array` object, with `rank=2` and `dtype=int64`:
2097
2098    >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy()
2099    array([[1, 2, 3], [4, 5, 6]])
2100
2101    Returns:
2102      A numpy `array`.
2103    """
2104    if not self._is_eager():
2105      raise ValueError("RaggedTensor.numpy() is only supported in eager mode.")
2106    values = self.values.numpy()
2107    splits = self.row_splits.numpy()
2108    rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
2109    if not rows:
2110      return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype)
2111    # Note: if `rows` have ragged lengths, then they will be stored in a
2112    # np.ndarray with dtype=object and rank=1.  If they have uniform lengths,
2113    # they will be combined into a single np.ndarray with dtype=row.dtype and
2114    # rank=row.rank+1.
2115    #
2116    # Manually set dtype as numpy now complains when given ragged rows.
2117    has_variable_length_rows = any(len(row) != len(rows[0]) for row in rows)
2118    dtype = np.object_ if has_variable_length_rows else None
2119    return np.array(rows, dtype=dtype)
2120
2121  def to_list(self):
2122    """Returns a nested Python `list` with the values for this `RaggedTensor`.
2123
2124    Requires that `rt` was constructed in eager execution mode.
2125
2126    Returns:
2127      A nested Python `list`.
2128    """
2129    if not isinstance(self.row_splits, ops.EagerTensor):
2130      raise ValueError("to_list can only be used in eager mode.")
2131    row_splits = self.row_splits.numpy().tolist()
2132    values = self.values
2133
2134    if isinstance(values, RaggedTensor):
2135      return [
2136          values[row_splits[i]:row_splits[i + 1]].to_list()
2137          for i in range(len(row_splits) - 1)
2138      ]
2139    else:
2140      # Convert values to a Python list.
2141      if hasattr(values, "numpy"):
2142        values_as_list = values.numpy().tolist()
2143      elif hasattr(values, "to_list"):
2144        values_as_list = values.to_list()
2145      else:
2146        raise ValueError("values must be convertible to a list")
2147
2148      return [
2149          values_as_list[row_splits[i]:row_splits[i + 1]]
2150          for i in range(len(row_splits) - 1)
2151      ]
2152
2153  def _eager_value(self):
2154    """Returns a RaggedTensorValue for self.  Requires self._is_eager()=true."""
2155    value = self.flat_values.numpy()
2156    for row_splits in reversed(self.nested_row_splits):
2157      value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy())
2158    return value
2159
2160  def _is_eager(self):
2161    """Returns True if values & row_splits Tensors are all `EagerTensor`s."""
2162    rt = self
2163    while isinstance(rt, RaggedTensor):
2164      if not isinstance(rt.row_splits, ops.EagerTensor):
2165        return False
2166      rt = rt.values
2167    return isinstance(rt, ops.EagerTensor)
2168
2169  #=============================================================================
2170  # Operators
2171  #=============================================================================
2172  # To avoid circular dependencies, we define stub methods for operators here,
2173  # and then override them when the ragged_operators module is imported.
2174
2175  def _overloaded_operator(name):  # pylint: disable=no-self-argument
2176
2177    def stub(*args, **kwargs):
2178      del args, kwargs
2179      raise ValueError(
2180          f"You must import 'tensorflow.python.ops.ragged.ragged_ops' "
2181          f"before using RaggedTensor.{name}.")
2182
2183    return stub
2184
  # Bind every overloadable operator to a stub that raises a helpful error.
  # Importing the ragged_operators module replaces these class attributes
  # with the real implementations (avoiding a circular dependency here).
  __getitem__ = _overloaded_operator("__getitem__")
  __ge__ = _overloaded_operator("__ge__")
  __gt__ = _overloaded_operator("__gt__")
  __le__ = _overloaded_operator("__le__")
  __lt__ = _overloaded_operator("__lt__")
  __and__ = _overloaded_operator("__and__")
  __rand__ = _overloaded_operator("__rand__")
  __invert__ = _overloaded_operator("__invert__")
  __ror__ = _overloaded_operator("__ror__")
  __or__ = _overloaded_operator("__or__")
  __xor__ = _overloaded_operator("__xor__")
  __rxor__ = _overloaded_operator("__rxor__")
  __abs__ = _overloaded_operator("__abs__")
  __add__ = _overloaded_operator("__add__")
  __radd__ = _overloaded_operator("__radd__")
  __div__ = _overloaded_operator("__div__")
  __rdiv__ = _overloaded_operator("__rdiv__")
  __floordiv__ = _overloaded_operator("__floordiv__")
  __rfloordiv__ = _overloaded_operator("__rfloordiv__")
  __mod__ = _overloaded_operator("__mod__")
  __rmod__ = _overloaded_operator("__rmod__")
  __mul__ = _overloaded_operator("__mul__")
  __rmul__ = _overloaded_operator("__rmul__")
  __neg__ = _overloaded_operator("__neg__")
  __pow__ = _overloaded_operator("__pow__")
  __rpow__ = _overloaded_operator("__rpow__")
  __sub__ = _overloaded_operator("__sub__")
  __rsub__ = _overloaded_operator("__rsub__")
  __truediv__ = _overloaded_operator("__truediv__")
  __rtruediv__ = _overloaded_operator("__rtruediv__")
  # The factory is only needed while building the class body; remove it so it
  # does not leak as a RaggedTensor method.
  del _overloaded_operator
2216
2217  #=============================================================================
2218  # Name Scope
2219  #=============================================================================
2220
2221  # This private function is used by ops.name_scope to ensure that all of the
2222  # input tensors for the scope belong to the same graph.  Defining this means
2223  # that you may include `RaggedTensor` objects in the name_scope `values`
2224  # list.
2225  def _as_graph_element(self):
2226    """Convert `self` to a graph element."""
2227    values = self.values
2228    while isinstance(values, RaggedTensor):
2229      values = values.values
2230    return values
2231
2232  #=============================================================================
2233  # Composite Tensor
2234  #=============================================================================
2235
  @property
  def _type_spec(self):
    """A `RaggedTensorSpec` describing this `RaggedTensor`'s static type."""
    return RaggedTensorSpec.from_value(self)
2239
  def _shape_invariant_to_type_spec(self, shape):
    """Returns a spec for this value with its shape relaxed to `shape`.

    NOTE(review): presumably invoked by control-flow ops (e.g. `tf.while_loop`)
    to apply user-supplied shape invariants — confirm against callers.
    """
    return RaggedTensorSpec(shape, self.dtype, self.ragged_rank,
                            self.row_splits.dtype)
2243
  def consumers(self):
    """Returns the consumers of this RaggedTensor's component tensors."""
    # Delegates to the composite-tensor `_consumers` implementation.
    return self._consumers()
2246
  # Gradient protocol for composite tensors: gradients are propagated through
  # this RaggedTensor's values (see WithValuesCompositeTensorGradient).
  __composite_gradient__ = (
      composite_tensor_gradient.WithValuesCompositeTensorGradient())
2249
2250
def is_ragged(value):
  """Returns true if `value` is a ragged tensor or ragged tensor value."""
  ragged_types = (RaggedTensor, ragged_tensor_value.RaggedTensorValue)
  return isinstance(value, ragged_types)
2255
2256
def match_row_splits_dtypes(*tensors, **kwargs):
  """Return a copy of `tensors` with row_splits all having the same dtype.

  Args:
    *tensors: A list of Tensors or RaggedTensors.
    **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
      where `dtype` is the data type used by row-splits, and `tensors` is the
      converted list of `Tensors` and `RaggedTensors`.

  Returns:
    The converted list of `Tensors` and `RaggedTensors`.
  """
  return_dtype = kwargs.pop("return_dtype", False)
  if kwargs:
    raise ValueError(f"Unexpected keyword args {kwargs}.")

  # Scan the ragged inputs to determine which row-splits dtypes occur.
  ragged_inputs = [t for t in tensors if isinstance(t, RaggedTensor)]
  has_int32 = any(
      t.row_splits.dtype == dtypes.int32 for t in ragged_inputs)
  has_int64 = any(
      t.row_splits.dtype != dtypes.int32 for t in ragged_inputs)

  if has_int32 and has_int64:
    # Mixed dtypes: upcast everything to int64 (when auto-casting is enabled).
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")
    dtype = dtypes.int64
    tensors = tuple(
        t.with_row_splits_dtype(dtypes.int64)
        if isinstance(t, RaggedTensor) else t
        for t in tensors)
  elif has_int32:
    dtype = dtypes.int32
  else:
    dtype = dtypes.int64

  return (dtype, tensors) if return_dtype else tensors
2302
2303
2304#===============================================================================
2305# RaggedTensorSpec
2306#===============================================================================
@tf_export("RaggedTensorSpec")
@type_spec.register("tf.RaggedTensorSpec")
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.RaggedTensor`."""

  __slots__ = [
      "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype",
      "_flat_values_spec"
  ]

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
    >>> tf.type_spec_from_value(rt).dtype
    tf.string

    Returns:
      A `tf.dtypes.DType` of the values in the RaggedTensor.
    """
    return self._dtype

  @property
  def shape(self):
    """The statically known shape of the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None])

    >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None, 2])

    Returns:
      A `tf.TensorShape` containing the statically known shape of the
      RaggedTensor. Ragged dimensions have a size of `None`.
    """
    return self._shape

  @property
  def ragged_rank(self):
    """The number of times the RaggedTensor's flat_values is partitioned.

    Defaults to `shape.ndims - 1`.

    Examples:

    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
    >>> tf.type_spec_from_value(values).ragged_rank
    1

    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
    >>> tf.type_spec_from_value(rt1).ragged_rank
    2

    Returns:
      A Python `int` indicating the number of times the underlying `flat_values`
      Tensor has been partitioned to add a new dimension.
      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
    """
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
    >>> tf.type_spec_from_value(rt).row_splits_dtype
    tf.int64

    Returns:
      A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
      of `tf.int32` or `tf.int64`.
    """
    return self._row_splits_dtype

  @property
  def flat_values_spec(self):
    """The `TypeSpec` of the flat_values of RaggedTensor.

    Returns:
      - The TypeSpec of flat_values.
      - None when the flat_values is a Tensor.
    """
    return self._flat_values_spec

  @property
  def value_type(self):
    """The Python type of values that are compatible with this TypeSpec."""
    return RaggedTensor if self._ragged_rank > 0 else ops.Tensor

  def __init__(self,
               shape=None,
               dtype=dtypes.float32,
               ragged_rank=None,
               row_splits_dtype=dtypes.int64,
               flat_values_spec=None):
    """Constructs a type specification for a `tf.RaggedTensor`.

    Args:
      shape: The shape of the RaggedTensor, or `None` to allow any shape.  If a
        shape is specified, then all ragged dimensions must have size `None`.
      dtype: `tf.DType` of values in the RaggedTensor.
      ragged_rank: Python integer, the number of times the RaggedTensor's
        flat_values is partitioned.  Defaults to `shape.ndims - 1`.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
        of `tf.int32` or `tf.int64`.
      flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be
        provided when the flat_values is a CompositeTensor rather then Tensor.
        If both `dtype` and `flat_values_spec` and  are provided, `dtype` must
        be the same as `flat_values_spec.dtype`. (experimental)
    """
    self._shape = tensor_shape.as_shape(shape)
    self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    # `dtype` and `flat_values_spec.dtype` must agree; either may supply the
    # value when the other is None (but not both None).
    if flat_values_spec is not None:
      if dtype is None:
        dtype = flat_values_spec.dtype
      elif dtype != flat_values_spec.dtype:
        raise ValueError("dtype must be the same as flat_values_spec.dtype")
    elif dtype is None:
      raise ValueError(
          "At least one of dtype or flat_values_spec must be provided")
    self._dtype = dtypes.as_dtype(dtype)
    self._flat_values_spec = flat_values_spec

    rank = self._shape.ndims
    if ragged_rank is None:
      if rank is None:
        raise ValueError("Must specify ragged_rank or "
                         "a shape with a known rank.")
      ragged_rank = rank - 1
    self._ragged_rank = ragged_rank
    if not isinstance(self._ragged_rank, int):
      raise TypeError(f"Argument `ragged_rank` must be an int. "
                      f"Received {ragged_rank}.")

    if rank is not None:
      if ragged_rank >= rank:
        raise ValueError(f"Argument `ragged_rank` ({ragged_rank}) must be less "
                         f"than rank ({rank}).")

  def is_compatible_with(self, spec_or_value):
    """Returns True if `spec_or_value` is compatible with this TypeSpec."""
    # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values.
    if self._ragged_rank == 0:
      if self._flat_values_spec is None:
        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
          return tensor_spec.TensorSpec(
              self._shape, self._dtype).is_compatible_with(spec_or_value)
      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
        return self._flat_values_spec.is_compatible_with(spec_or_value)
    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)

  def _serialize(self):
    """Returns the constructor args as a tuple (used for equality/hashing)."""
    # Omit flat_values_spec when unset, for compatibility with serializations
    # produced before that field existed.
    if self._flat_values_spec is None:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype)
    else:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype, self._flat_values_spec)

  @property
  def _component_specs(self):
    """TypeSpecs for the components: [flat_values, *nested_row_splits]."""
    if self._ragged_rank <= 0:
      # No partitions: the value is just the flat_values.
      if self._flat_values_spec is not None:
        return [self._flat_values_spec]
      else:
        return [tensor_spec.TensorSpec(self._shape, self._dtype)]

    flat_values_spec = self._flat_values_spec
    if flat_values_spec is None:
      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
          self._shape[self._ragged_rank + 1:])
      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
    # The outermost row_splits has length nrows+1, which is statically known
    # when the outer dimension is; inner splits lengths are unknown.
    outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
    outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
    inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

    specs = ([
        flat_values_spec,
        tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
    ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
    return specs

  def _to_components(self, value):
    """Decomposes `value` into [flat_values, *nested_row_splits]."""
    if is_ragged(value):
      return [value.flat_values] + list(value.nested_row_splits)
    else:
      return [value]

  def _from_components(self, tensor_list):
    """Reassembles a value from [flat_values, *nested_row_splits]."""
    result = tensor_list[0]
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      # TF1 graph mode with numpy components: build a RaggedTensorValue.
      for row_splits in reversed(tensor_list[1:]):
        result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
    else:
      if isinstance(tensor_list[0], np.ndarray):
        tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
        result = tensor_list[0]
      # Wrap with one RowPartition per splits tensor, innermost first.
      for row_splits in reversed(tensor_list[1:]):
        result = RaggedTensor(
            result,
            RowPartition.from_row_splits(row_splits, validate=False),
            internal=True)
    if self._shape.ndims is not None:
      if isinstance(result, RaggedTensor):
        result._set_shape(self._shape)  # pylint: disable=protected-access
        # TODO(xjun): MaskedTensor doesn't implement set_shape.
        if self.flat_values_spec is not None and hasattr(result.flat_values,
                                                         "set_shape"):
          result.flat_values.set_shape(self.flat_values_spec.shape)
      elif isinstance(result, ops.Tensor):
        result.set_shape(self._shape)
    return result

  # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    """Spec for the single variant tensor used by the tensor_list encoding."""
    # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
    # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
    # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    """Boxes `value` into a single (unbatched) variant tensor."""
    # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
    # from variant to include all of the row-partitioning tensors.
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    if isinstance(value, RaggedTensor):
      if value.ragged_rank != self._ragged_rank:
        raise ValueError(
            f"Ragged rank of value {value.ragged_rank} does not match "
            f"ragged rank of type {self._ragged_rank}.")
      # pylint: disable=protected-access
      return [value._to_variant(batched_input=False)]
    else:
      if self._ragged_rank > 0:
        raise ValueError(
            f"Expected a RaggedTensor if ragged rank={self._ragged_rank}"
            f" but got {type(value).__name__}."
        )
      return [
          gen_ragged_conversion_ops.ragged_tensor_to_variant(
              (), value, batched_input=False)
      ]

  def _to_batched_tensor_list(self, value):
    """Boxes `value` into a single batched variant tensor."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    if isinstance(value, RaggedTensor):
      if value.ragged_rank != self._ragged_rank:
        raise ValueError(
            f"Ragged rank of value {value.ragged_rank} does not match "
            f"ragged rank of type {self._ragged_rank}.")
      # pylint: disable=protected-access
      return [value._to_variant(batched_input=True)]
    else:
      if self._ragged_rank > 0:
        raise ValueError(
            f"Expected a RaggedTensor if ragged rank={self._ragged_rank}"
            f" but got {type(value).__name__}."
        )
      return [
          gen_ragged_conversion_ops.ragged_tensor_to_variant(
              rt_nested_splits=(), rt_dense_values=value, batched_input=True)
      ]

  def _from_compatible_tensor_list(self, tensor_list):
    """Unboxes a variant tensor back into a RaggedTensor (or Tensor)."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    result = RaggedTensor._from_variant(  # pylint: disable=protected-access
        tensor_list[0],
        dtype=self._dtype,
        row_splits_dtype=self._row_splits_dtype,
        output_ragged_rank=self._ragged_rank)
    if self._shape.ndims is not None:
      if isinstance(result, RaggedTensor):
        result._set_shape(self._shape)  # pylint: disable=protected-access
        # TODO(xjun): MaskedTensor doesn't implement set_shape.
        # Fixed: check `result.flat_values` (mirroring `_from_components`);
        # the spec itself has no `flat_values` attribute.
        if self.flat_values_spec is not None and hasattr(result.flat_values,
                                                         "set_shape"):
          result.flat_values.set_shape(self.flat_values_spec.shape)
      else:
        result.set_shape(self._shape)
    return result

  def _batch(self, batch_size):
    """Returns a spec for a batch of `batch_size` values of this spec."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    return RaggedTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype, self._ragged_rank + 1, self._row_splits_dtype)

  def _unbatch(self):
    """Returns a spec for a single element of this (batched) spec."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    # Note: Negative ragged_rank is allowed here because the dataset could be
    # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is
    # consistent. Errors are handled in
    # RaggedTensorSpec._from_compatible_tensor_list()
    return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1,
                            self._row_splits_dtype)

  def _to_legacy_output_types(self):
    """Legacy (tf.data v1) output-types encoding: the value dtype."""
    return self._dtype

  def _to_legacy_output_shapes(self):
    """Legacy (tf.data v1) output-shapes encoding: the static shape."""
    return self._shape

  def _to_legacy_output_classes(self):
    """Legacy (tf.data v1) output-classes encoding: the spec itself."""
    return self

  @classmethod
  def from_value(cls, value):
    """Builds a `RaggedTensorSpec` describing `value`.

    Args:
      value: A `RaggedTensor` or `RaggedTensorValue`.

    Returns:
      A `RaggedTensorSpec` compatible with `value`.
    """
    if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or
        isinstance(value.flat_values, ops.Tensor)):
      return cls(
          shape=value.shape,
          dtype=value.values.dtype,
          ragged_rank=value.ragged_rank,
          row_splits_dtype=value.row_splits.dtype)
    else:
      # flat_values is itself a CompositeTensor; capture its spec too.
      flat_values_spec = type_spec.type_spec_from_value(value.flat_values)
      # Relax shape[0] to None, as it is connected to dynamic ragged shapes.
      flat_values_spec = flat_values_spec._unbatch()._batch(None)  # pylint: disable=protected-access
      return cls(
          shape=value.shape,
          dtype=value.values.dtype,
          ragged_rank=value.ragged_rank,
          row_splits_dtype=value.row_splits.dtype,
          flat_values_spec=flat_values_spec)
2648
2649
# Allow `tf.type_spec_from_value` to handle RaggedTensorValue objects by
# routing them through RaggedTensorSpec.from_value.
type_spec.register_type_spec_from_value_converter(
    ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value)
2652
2653
2654#===============================================================================
2655# Convert value -> tensor
2656#===============================================================================
def convert_to_tensor_or_ragged_tensor(value,
                                       dtype=None,
                                       preferred_dtype=None,
                                       name=None):
  """Converts value to a `RaggedTensor` or `Tensor`.

  * If `value` is a `RaggedTensor`, then return it as-is.
  * If `value` is a `RaggedTensorValue`, return a corresponding constant
    `RaggedTensor`.
  * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`.

  Args:
    value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has
      a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor.  If missing the type
      is inferred from the type of `value`.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None.  This argument has no effect if `value` is already a
      tensor, or when conversion is not possible.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `RaggedTensor`.
  """
  # Case 1: already a RaggedTensor -- just validate the requested dtype.
  if isinstance(value, RaggedTensor):
    if dtype and not dtype.is_compatible_with(value.dtype):
      raise ValueError(f"Tensor conversion requested dtype {dtype.name} for "
                       f"RaggedTensor with dtype {value.dtype.name}: {value}.")
    return value

  # Case 2: a RaggedTensorValue -- convert its flat values and rebuild.
  if isinstance(value, ragged_tensor_value.RaggedTensorValue):
    with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []):
      flat_values = ops.convert_to_tensor(
          value=value.flat_values,
          dtype=dtype,
          dtype_hint=preferred_dtype,
          name="flat_values")
      return RaggedTensor.from_nested_row_splits(
          flat_values, value.nested_row_splits, validate=False)

  # Case 3: anything else -- defer to the standard Tensor conversion.
  return ops.convert_to_tensor_v2_with_dispatch(
      value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name)
2698
2699
def _convert_to_ragged_tensor_values(value):
  """Converts value to a supported RaggedTensor value type.

  * If `value` is already an object of a supported value type, it is returned
    as-is.
  * Otherwise it is converted to a `Tensor` or `RaggedTensor`.

  Args:
    value: An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
      value types, or an object whose type has a registered `Tensor` conversion
      function.

  Returns:
    An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
    value types.
  """
  if _is_supported_ragged_values_type(value):
    return value
  return convert_to_tensor_or_ragged_tensor(value, name="values")
2719
2720
2721#===============================================================================
2722# Register RaggedTensor for use with session.run.
2723#===============================================================================
def _ragged_tensor_value_from_components(components):
  """Rebuilds a RaggedTensorValue from [*nested_row_splits, flat_values]."""
  # The last component is the flat values; the preceding components are the
  # nested row splits, ordered outermost-first.  Wrap innermost-first.
  *splits, value = list(components)
  for row_splits in reversed(splits):
    value = ragged_tensor_value.RaggedTensorValue(value, row_splits)
  return value
2730
2731
def _ragged_tensor_session_fetch(rt):
  """Session-fetch conversion: fetch components, rebuild a value afterwards."""
  fetches = rt.nested_row_splits + (rt.flat_values,)
  return (fetches, _ragged_tensor_value_from_components)
2735
2736
2737def _ragged_tensor_session_feed(feed_key, feed_val):
2738  key_components = feed_key.nested_row_splits + (feed_key.flat_values,)
2739  val_components = feed_val.nested_row_splits + (feed_val.flat_values,)
2740  return zip(key_components, val_components)
2741
2742
2743def _ragged_tensor_session_feed_for_partial_run(feed_key):
2744  return feed_key.nested_row_splits + (feed_key.flat_values,)
2745
2746
# Teach tf.compat.v1.Session.run how to fetch and feed RaggedTensors by
# decomposing them into their row-splits and flat-values components.
session.register_session_run_conversion_functions(
    RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed,
    _ragged_tensor_session_feed_for_partial_run)
2750
2751
2752#===============================================================================
2753# RaggedTensorType
2754#===============================================================================
class RaggedTensorType:
  """Encoding of a static type for a `RaggedTensor`.

  Use this type to express/declare that an output must have the type of
  `RaggedTensor`.
  """

  def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64):
    """Initializes a RaggedTensorType object.

    Args:
      dtype: data type of the `RaggedTensor`'s inner values.
      ragged_rank: ragged_rank of the declared `RaggedTensor`.
      row_splits_dtype: data type for the `RaggedTensor`'s row splits.
        One of: `tf.int32` or `tf.int64`.
    """
    self._dtype = dtype
    self._ragged_rank = ragged_rank
    self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)

  # Idiomatic read-only accessors (replaces `property(lambda self: ...)`).
  @property
  def dtype(self):
    """Data type of the `RaggedTensor`'s inner values."""
    return self._dtype

  @property
  def ragged_rank(self):
    """Declared ragged_rank of the `RaggedTensor`."""
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """Data type of the `RaggedTensor`'s row splits."""
    return self._row_splits_dtype

  def __repr__(self):
    return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank,
                                             self.row_splits_dtype)
2783
2784
2785#===============================================================================
2786# Helper Functions
2787#===============================================================================
def _assert_sparse_indices_are_ragged_right(indices):
  """Checks that the given SparseTensor.indices tensor is ragged-right.

  Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right
  because the entry `[3, 1]` skips a cell.

  Args:
    indices: The SparseTensor indices to check.

  Returns:
    A list of control dependency op tensors.
  """
  index_prefix = indices[:, :-1]  # all dimensions except the innermost
  index_suffix = indices[:, -1]   # the innermost (column) dimension

  # Check whether each index is starting a new row in the innermost dimension
  # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]).
  # (Note: this skips the first index; we will check that separately below.)
  index_prefix_changed = math_ops.reduce_any(
      math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1)

  # Check two cases:
  #   * For indices that start a new row: index_suffix[i] must be zero.
  #   * For indices that continue a row: index_suffix[i] must be equal to
  #     index_suffix[i-1]+1.
  index_ok = array_ops.where(
      index_prefix_changed, math_ops.equal(index_suffix[1:], 0),
      math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1))

  # Also check that the very first index didn't skip any cells.  The first
  # index starts a new row (by definition), so its suffix should be zero.
  # (Using `index_suffix[:1]` rather than `[0]` keeps this safe when
  # `indices` is empty.)
  sparse_indices_are_ragged_right = math_ops.logical_and(
      math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)),
      math_ops.reduce_all(index_ok))

  message = [
      "SparseTensor is not right-ragged", "SparseTensor.indices =", indices
  ]
  return [control_flow_ops.Assert(sparse_indices_are_ragged_right, message)]
2827
2828
@ops.RegisterGradient("RaggedTensorToSparse")
def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad,
                                      sparse_values_grad,
                                      unused_sparse_shape_grad):
  """Gradient for RaggedTensorToSparse."""
  nested_row_splits = op.inputs[:-1]
  flat_values = op.inputs[-1]

  # The row-partitioning inputs receive no gradient.
  splits_grads = [None] * len(nested_row_splits)

  # The flat_values gradient is the SparseTensor's values gradient, reshaped
  # back to the shape of the original flat_values input.
  values_grad = array_ops.reshape(sparse_values_grad,
                                  array_ops.shape(flat_values))

  return splits_grads + [values_grad]
2847
2848
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that `tensor` is sorted in non-decreasing order."""
  deltas = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(deltas, message=message)
2852
2853
def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` equals zero."""
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)
2857
2858
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the number of rows (dimension-0 size) of `tensor`."""
  if isinstance(tensor, RaggedTensor):
    return tensor.nrows(out_type=out_type)
  return array_ops.shape(tensor, out_type=out_type)[0]
2864
2865
def merge_dims(value, outer_axis, inner_axis):
  """Merges value[outer_axis...inner_axis] into a single dimension.

  See `RaggedTensor.merge_dims()` for more details.  This helper differs from
  `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor.

  Args:
    value: A `RaggedTensor` or `Tensor`
    outer_axis: `int`
    inner_axis: `int`

  Returns:
    A flattened `RaggedTensor` or `Tensor`.
  """
  if outer_axis == inner_axis:
    return value

  # Flatten outer dimensions of a RaggedTensor by just taking its values.
  while outer_axis == 0 and isinstance(value, RaggedTensor):
    value = value.values
    inner_axis -= 1
    if inner_axis == 0:
      return value

  # Flatten non-Ragged tensors using tf.reshape().
  if not isinstance(value, RaggedTensor):
    if value.shape.is_fully_defined():
      # Static shape: compute the merged shape in Python.
      old_shape = value.shape.as_list()
      new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:]
    else:
      # Dynamic shape: compute the merged shape in the graph.
      old_shape = array_ops.shape(value)
      new_shape = array_ops.concat(
          [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0)
    return array_ops.reshape(value, new_shape)

  # Handle outer_axis>1 via recursion.
  if outer_axis > 1:
    return value.with_values(
        merge_dims(value.values, outer_axis - 1, inner_axis - 1))

  # At this point, we know outer_axis == 1, and value is a RaggedTensor.
  # So we need to flatten the values and build a corresponding splits tensor.
  new_values = value.values
  new_splits = value.row_splits
  # Each loop iteration merges one more dimension into the result.
  for axis in range(outer_axis, inner_axis):
    if isinstance(new_values, RaggedTensor):
      # Flatten a single ragged dimension.
      # Composing splits: gather maps each outer split offset through the
      # inner partition's row_splits, yielding splits into the inner values.
      new_splits = array_ops.gather(new_values.row_splits, new_splits)
      new_values = new_values.values
    else:
      # Flatten all remaining dense dimensions.
      shape_split = inner_axis - axis + 1
      if new_values.shape.is_fully_defined():
        old_shape = new_values.shape.as_list()
        new_shape = [-1] + old_shape[shape_split:]
        flat_size = _prod(old_shape[1:shape_split])
      else:
        old_shape = array_ops.shape(new_values)
        new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0)
        flat_size = math_ops.cast(
            math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype)
      new_values = array_ops.reshape(new_values, new_shape)
      # Scale splits to index into the reshaped (flattened) values.
      new_splits = new_splits * flat_size
      # All remaining dims were merged by the reshape, so we are done.
      break
  return RaggedTensor.from_row_splits(new_values, new_splits)
2931
2932
2933def _prod(lst):
2934  """Returns the product of the numbers in a list."""
2935  return functools.reduce(operator.mul, lst, 1)
2936
2937
2938def _get_row_partition_type_tensor_pairs_tail(partition):
2939  """Gets a row partition type tensor pair for the tail.
2940
2941  If value_rowid is defined, then it is used. Otherwise, row_splits
2942  are used.
2943
2944  Args:
2945    partition: a RowPartition.
2946
2947  Returns:
2948    A list of (row_partition_type, row_partition_tensor) pairs.
2949  """
2950  if partition._has_precomputed_value_rowids():  # pylint: disable=protected-access
2951    return ("VALUE_ROWIDS", partition.value_rowids())
2952  else:
2953    return ("ROW_SPLITS", partition.row_splits())
2954
2955
def _get_row_partition_type_tensor_pairs(rt_input):
  """Gets a list of the row partition type/tensor pairs for rt_input.

  If value_rowids are defined, then they are used. Otherwise, row_splits
  are used. If the outermost level has value_rowids defined, then nrows is
  also added.

  Args:
    rt_input: a ragged tensor.

  Returns:
    A list of (row_partition_type, row_partition_tensor) pairs.
  """
  # pylint: disable=protected-access
  partitions = rt_input._nested_row_partitions
  outermost = partitions[0]

  result = []
  if outermost._value_rowids is not None:
    # value_rowids alone don't determine the number of rows, so the
    # outermost level also carries an explicit nrows entry.
    result.append(("FIRST_DIM_SIZE", outermost.nrows()))
    result.append(("VALUE_ROWIDS", outermost.value_rowids()))
  else:
    result.append(("ROW_SPLITS", outermost.row_splits()))
  result.extend(
      _get_row_partition_type_tensor_pairs_tail(p) for p in partitions[1:])
  return result
2977
2978
def _shape_as_tensor(shape, dtype):
  """Takes shape and coerces it to a shape as a tensor.

  If the object is already a tensor, simply passes it on (result is guaranteed
  to be int64 or int32, but not necessarily dtype).
  If not, creates a tensor of type dtype.

  Result is either a scalar equal to -1 if the shape is unknown_rank.
  Otherwise, it is a vector, where unknown dimensions are represented with a
  value of -1.

  In C++, see TensorShapeFromTensor for parsing shapes in kernels, and
  InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for
  use in the shape inference function.

  Args:
    shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]],
      Tuple[Optional[Int]].
    dtype: tf.int64 or tf.int32

  Returns:
    a scalar or vector tensor of dtype tf.int32 or tf.int64.
  """
  if dtype != dtypes.int64 and dtype != dtypes.int32:
    raise ValueError(f"Expected int64 or int32 for dtype: got {dtype}.")

  if isinstance(shape, ops.Tensor):
    # Pass integral tensors through unchanged; cast anything else to `dtype`.
    if shape.dtype == dtypes.int64 or shape.dtype == dtypes.int32:
      return shape
    return math_ops.cast(shape, dtype)

  shape = tensor_shape.as_shape(shape)
  if not shape:
    # A TensorShape with unknown rank is falsy; encode unknown rank as a
    # scalar -1.
    return constant_op.constant(-1, dtype=dtype)
  # Known rank: encode each unknown dimension as -1 in a vector.
  dims = [-1 if dim is None else dim for dim in shape.as_list()]
  return constant_op.constant(dims, dtype=dtype)
3016
3017
def _nvals_uniform_row_length(values, uniform_row_length):
  """Get the number of values for uniform row length constructor.

  Args:
    values: The values of the RaggedTensor under construction.
    uniform_row_length: A scalar tensor; its dtype determines the dtype of
      the result.

  Returns:
    A scalar tensor giving the size of dimension 0 of `values`.
  """
  # Prefer the statically-known size when available.
  static_nvals = tensor_shape.dimension_at_index(values.shape, 0).value
  if static_nvals is not None:
    return constant_op.constant(static_nvals, uniform_row_length.dtype)
  if isinstance(values, RaggedTensor):
    return values.nrows(out_type=uniform_row_length.dtype)
  return array_ops.shape(values, out_type=uniform_row_length.dtype)[0]
3028
3029
def _get_optional_partition_dtype(values):
  """Returns the row-partition dtype of `values`, or None if none exists."""
  if not isinstance(values, RaggedTensor):
    return None
  return values._row_partition.dtype  # pylint: disable=protected-access
3036
3037
# Registry of types accepted as the `values` of a RaggedTensor; extended at
# runtime via `_add_supported_value_type()`.
_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
3039
3040
3041# TODO(edloper): Consider whether we should change the registry to be on
3042# TypeSpecs rather than ValueTypes.
def _add_supported_value_type(cls):
  """Registers `cls` as a supported value type for RaggedTensor.

  The cls must be a subclass of CompositeTensor, and must support:
   - Spec:
     The Spec must be a `BatchableTypeSpec`
   - Properties:
     - x.shape
     - x.dtype
   - Methods:
     - x.__getitem__(idx) (method: returns a supported value type)
     - x.set_shape(shape)
   - Ops:
     - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor.
     - tf.tile(x)
     - assert_rank_at_least(x)
     - tf.ones_like(x)
     - tf.gather(params=x, indices=Tensor)
     - tf.add(x, y)
     - tf.boolean_mask(x, ...)
     - @TODO(edloper): Complete this list

   Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not
   currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor:
     - rt.to_tensor()
     - rt.to_sparse_tensor()
     - rt._to_variant()
     - rt._from_variant()
     - tf.ragged.cross([rt])
     - tf.gather(params=x, indices=rt)  # rt used for indices
     - RaggedTensorSpec methods:
       - _batch
       - _unbatch
       - _to_tensor_list
       - _to_batched_tensor_list
       - _from_compatible_tensor_list

  Args:
    cls: The type to be added to supported value types.

  Raises:
    ValueError: If `cls` is not a CompositeTensor subclass, or lacks the
      `shape` or `dtype` property.
  """
  if not issubclass(cls, composite_tensor.CompositeTensor):
    raise ValueError(f"cls ({cls}) must be a subclass of CompositeTensor.")
  for required_property in ("shape", "dtype"):
    if not hasattr(cls, required_property):
      raise ValueError(f"cls must support the `{required_property}` property.")
  # Extend the module-level registry (a tuple, so rebind rather than mutate).
  global _SUPPORTED_RAGGED_VALUE_TYPES
  _SUPPORTED_RAGGED_VALUE_TYPES += (cls,)
3091
3092
def _is_supported_ragged_values_type(value):
  """Returns True if `value` is an instance of a registered ragged value type."""
  return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES)
3095
3096
def _assert_is_supported_ragged_values_type(value):
  """Raises TypeError if `value` is not a registered ragged value type."""
  if _is_supported_ragged_values_type(value):
    return
  ok_types = ", ".join(cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES)
  raise TypeError(f"type(values) must be one of: {ok_types}, got {value}.")
3101
3102
3103def _formatter(x):
3104  """Separate Numpy array elements with comma."""
3105  if isinstance(x, np.ndarray):
3106    if x.size != 0:
3107      return np.array2string(x, separator=", ")
3108    else:
3109      # When x.size==0, np.array2string always returns `[]`.  This isn't always
3110      # what we want.  E.g., if `x.shape=[0, 3]`, then we want `[[], [], []]`.
3111      return repr(x.tolist())
3112  else:
3113    return str(x)
3114
# Type annotation indicating that a value is ragged.  Includes RaggedTensor
# as well as the (deprecated) RaggedTensorValue class from TF 1.x.
Ragged = typing.Union[RaggedTensor, ragged_tensor_value.RaggedTensorValue]

# Type annotation indicating that a value is a ragged tensor, a dense tensor,
# or a value that can be converted to a tensor (e.g. np.array).
# TODO(edloper): Add Variable to TensorLike, and remove it from here.
# NOTE(review): `core_types` is not among the imports visible in this chunk —
# presumably imported elsewhere in the file; verify.
RaggedOrDense = typing.Union[Ragged, core_types.TensorLike]
3123