• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Classes for storing ragged tensors and their values."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import operator
23
24import numpy as np
25
26from tensorflow.python import tf2
27from tensorflow.python.client import session
28from tensorflow.python.framework import composite_tensor
29from tensorflow.python.framework import constant_op
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import sparse_tensor
33from tensorflow.python.framework import tensor_shape
34from tensorflow.python.framework import tensor_spec
35from tensorflow.python.framework import tensor_util
36from tensorflow.python.framework import type_spec
37from tensorflow.python.ops import array_ops
38from tensorflow.python.ops import check_ops
39from tensorflow.python.ops import control_flow_ops
40from tensorflow.python.ops import gen_ragged_conversion_ops
41from tensorflow.python.ops import math_ops
42from tensorflow.python.ops.ragged import ragged_config
43from tensorflow.python.ops.ragged import ragged_tensor_value
44from tensorflow.python.ops.ragged import ragged_util
45from tensorflow.python.ops.ragged.row_partition import RowPartition
46from tensorflow.python.types import internal as internal_types
47from tensorflow.python.util import dispatch
48from tensorflow.python.util.tf_export import tf_export
49from tensorflow.tools.docs import doc_controls
50
# Module-level alias for RowPartition's private helper.  Used below (e.g. in
# `from_uniform_row_length`) to coerce a user-supplied row-partitioning value
# to a tensor with a preferred dtype — presumably mirroring RowPartition's own
# conversion rules; see RowPartition._convert_row_partition for the contract.
# The pylint pragmas scope the protected-access waiver to this alias only.
# pylint: disable=protected-access
_convert_row_partition = RowPartition._convert_row_partition
# pylint: enable=protected-access
54
55#===============================================================================
56# RaggedTensor
57#===============================================================================
58
59
60@tf_export("RaggedTensor")
61class RaggedTensor(composite_tensor.CompositeTensor,
62                   internal_types.NativeObject):
63  """Represents a ragged tensor.
64
65  A `RaggedTensor` is a tensor with one or more *ragged dimensions*, which are
66  dimensions whose slices may have different lengths.  For example, the inner
67  (column) dimension of `rt=[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is ragged,
68  since the column slices (`rt[0, :]`, ..., `rt[4, :]`) have different lengths.
69  Dimensions whose slices all have the same length are called *uniform
70  dimensions*.  The outermost dimension of a `RaggedTensor` is always uniform,
71  since it consists of a single slice (and so there is no possibility for
72  differing slice lengths).
73
74  The total number of dimensions in a `RaggedTensor` is called its *rank*,
75  and the number of ragged dimensions in a `RaggedTensor` is called its
76  *ragged-rank*.  A `RaggedTensor`'s ragged-rank is fixed at graph creation
77  time: it can't depend on the runtime values of `Tensor`s, and can't vary
78  dynamically for different session runs.
79
80  Note that the `__init__` constructor is private. Please use one of the
81  following methods to construct a `RaggedTensor`:
82
83  * `tf.RaggedTensor.from_row_lengths`
84  * `tf.RaggedTensor.from_value_rowids`
85  * `tf.RaggedTensor.from_row_splits`
86  * `tf.RaggedTensor.from_row_starts`
87  * `tf.RaggedTensor.from_row_limits`
88  * `tf.RaggedTensor.from_nested_row_splits`
89  * `tf.RaggedTensor.from_nested_row_lengths`
90  * `tf.RaggedTensor.from_nested_value_rowids`
91
92  ### Potentially Ragged Tensors
93
94  Many ops support both `Tensor`s and `RaggedTensor`s
95  (see [tf.ragged](https://www.tensorflow.org/api_docs/python/tf/ragged) for a
96  full listing). The term "potentially ragged tensor" may be used to refer to a
97  tensor that might be either a `Tensor` or a `RaggedTensor`.  The ragged-rank
98  of a `Tensor` is zero.
99
100  ### Documenting RaggedTensor Shapes
101
102  When documenting the shape of a RaggedTensor, ragged dimensions can be
103  indicated by enclosing them in parentheses.  For example, the shape of
104  a 3-D `RaggedTensor` that stores the fixed-size word embedding for each
105  word in a sentence, for each sentence in a batch, could be written as
106  `[num_sentences, (num_words), embedding_size]`.  The parentheses around
107  `(num_words)` indicate that dimension is ragged, and that the length
108  of each element list in that dimension may vary for each item.
109
110  ### Component Tensors
111
112  Internally, a `RaggedTensor` consists of a concatenated list of values that
113  are partitioned into variable-length rows.  In particular, each `RaggedTensor`
114  consists of:
115
116    * A `values` tensor, which concatenates the variable-length rows into a
117      flattened list.  For example, the `values` tensor for
118      `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]` is `[3, 1, 4, 1, 5, 9, 2, 6]`.
119
120    * A `row_splits` vector, which indicates how those flattened values are
121      divided into rows.  In particular, the values for row `rt[i]` are stored
122      in the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.
123
124  Example:
125
126  >>> print(tf.RaggedTensor.from_row_splits(
127  ...       values=[3, 1, 4, 1, 5, 9, 2, 6],
128  ...       row_splits=[0, 4, 4, 7, 8, 8]))
129  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
130
131  ### Alternative Row-Partitioning Schemes
132
133  In addition to `row_splits`, ragged tensors provide support for five other
134  row-partitioning schemes:
135
136    * `row_lengths`: a vector with shape `[nrows]`, which specifies the length
137      of each row.
138
139    * `value_rowids` and `nrows`: `value_rowids` is a vector with shape
140      `[nvals]`, corresponding one-to-one with `values`, which specifies
141      each value's row index.  In particular, the row `rt[row]` consists of the
142      values `rt.values[j]` where `value_rowids[j]==row`.  `nrows` is an
143      integer scalar that specifies the number of rows in the
144      `RaggedTensor`. (`nrows` is used to indicate trailing empty rows.)
145
146    * `row_starts`: a vector with shape `[nrows]`, which specifies the start
147      offset of each row.  Equivalent to `row_splits[:-1]`.
148
149    * `row_limits`: a vector with shape `[nrows]`, which specifies the stop
150      offset of each row.  Equivalent to `row_splits[1:]`.
151
152    * `uniform_row_length`: A scalar tensor, specifying the length of every
153      row.  This row-partitioning scheme may only be used if all rows have
154      the same length.
155
156  Example: The following ragged tensors are equivalent, and all represent the
157  nested list `[[3, 1, 4, 1], [], [5, 9, 2], [6], []]`.
158
159  >>> values = [3, 1, 4, 1, 5, 9, 2, 6]
160  >>> RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
161  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
162  >>> RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
163  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
164  >>> RaggedTensor.from_value_rowids(
165  ...     values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
166  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
167  >>> RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
168  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
169  >>> RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
170  <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
171  >>> RaggedTensor.from_uniform_row_length(values, uniform_row_length=2)
172  <tf.RaggedTensor [[3, 1], [4, 1], [5, 9], [2, 6]]>
173
174  ### Multiple Ragged Dimensions
175
176  `RaggedTensor`s with multiple ragged dimensions can be defined by using
177  a nested `RaggedTensor` for the `values` tensor.  Each nested `RaggedTensor`
178  adds a single ragged dimension.
179
180  >>> inner_rt = RaggedTensor.from_row_splits(  # =rt1 from above
181  ...     values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
182  >>> outer_rt = RaggedTensor.from_row_splits(
183  ...     values=inner_rt, row_splits=[0, 3, 3, 5])
184  >>> print(outer_rt.to_list())
185  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
186  >>> print(outer_rt.ragged_rank)
187  2
188
189  The factory function `RaggedTensor.from_nested_row_splits` may be used to
190  construct a `RaggedTensor` with multiple ragged dimensions directly, by
191  providing a list of `row_splits` tensors:
192
193  >>> RaggedTensor.from_nested_row_splits(
194  ...     flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
195  ...     nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8])).to_list()
196  [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]
197
198  ### Uniform Inner Dimensions
199
200  `RaggedTensor`s with uniform inner dimensions can be defined
201  by using a multidimensional `Tensor` for `values`.
202
203  >>> rt = RaggedTensor.from_row_splits(values=tf.ones([5, 3], tf.int32),
204  ...                                   row_splits=[0, 2, 5])
205  >>> print(rt.to_list())
206  [[[1, 1, 1], [1, 1, 1]],
207   [[1, 1, 1], [1, 1, 1], [1, 1, 1]]]
208  >>> print(rt.shape)
209  (2, None, 3)
210
211  ### Uniform Outer Dimensions
212
213  `RaggedTensor`s with uniform outer dimensions can be defined by using
214  one or more `RaggedTensor` with a `uniform_row_length` row-partitioning
215  tensor.  For example, a `RaggedTensor` with shape `[2, 2, None]` can be
216  constructed with this method from a `RaggedTensor` values with shape
217  `[4, None]`:
218
219  >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
220  >>> print(values.shape)
221  (4, None)
222  >>> rt6 = tf.RaggedTensor.from_uniform_row_length(values, 2)
223  >>> print(rt6)
224  <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
225  >>> print(rt6.shape)
226  (2, 2, None)
227
228  Note that `rt6` only contains one ragged dimension (the innermost
229  dimension). In contrast, if `from_row_splits` is used to construct a similar
230  `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
231
232  >>> rt7 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
233  >>> print(rt7.shape)
234  (2, None, None)
235
236  Uniform and ragged outer dimensions may be interleaved, meaning that a
237  tensor with any combination of ragged and uniform dimensions may be created.
238  For example, a RaggedTensor `t4` with shape `[3, None, 4, 8, None, 2]` could
239  be constructed as follows:
240
241  ```python
242  t0 = tf.zeros([1000, 2])                           # Shape:         [1000, 2]
243  t1 = RaggedTensor.from_row_lengths(t0, [...])      #           [160, None, 2]
244  t2 = RaggedTensor.from_uniform_row_length(t1, 8)   #         [20, 8, None, 2]
245  t3 = RaggedTensor.from_uniform_row_length(t2, 4)   #       [5, 4, 8, None, 2]
246  t4 = RaggedTensor.from_row_lengths(t3, [...])      # [3, None, 4, 8, None, 2]
247  ```
248
249  """
250
251  #=============================================================================
252  # Constructor (private)
253  #=============================================================================
  @doc_controls.do_not_generate_docs
  def __init__(self, values, row_partition, internal=False):
    """Creates a `RaggedTensor` with a specified partitioning for `values`.

    This constructor is private -- please use one of the following ops to
    build `RaggedTensor`s:

      * `tf.RaggedTensor.from_row_lengths`
      * `tf.RaggedTensor.from_value_rowids`
      * `tf.RaggedTensor.from_row_splits`
      * `tf.RaggedTensor.from_row_starts`
      * `tf.RaggedTensor.from_row_limits`
      * `tf.RaggedTensor.from_nested_row_splits`
      * `tf.RaggedTensor.from_nested_row_lengths`
      * `tf.RaggedTensor.from_nested_value_rowids`

    Args:
      values: A potentially ragged tensor of any dtype and shape `[nvals, ...]`.
      row_partition: A `RowPartition` object, representing the arrangement of
        the lists at the top level.
      internal: True if the constructor is being called by one of the factory
        methods.  If false, an exception will be raised.

    Raises:
      ValueError: If internal = False. Note that this method is intended only
                 for internal use.
      TypeError: If values is not a `RaggedTensor` or `Tensor`, or
                 row_partition is not a `RowPartition`.
    """

    # Guard against direct construction: only the factory methods (which pass
    # internal=True) are allowed to call this constructor.
    if not internal:
      raise ValueError("RaggedTensor constructor is private; please use one "
                       "of the factory methods instead (e.g., "
                       "RaggedTensor.from_row_lengths())")
    # Raises TypeError unless `values` is an accepted values type
    # (Tensor or RaggedTensor -- see the helper for the exact set).
    _assert_is_supported_ragged_values_type(values)
    if not isinstance(row_partition, RowPartition):
      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
                      f"Received {row_partition}.")

    # Validate shapes.
    values.shape.with_rank_at_least(1)
    if isinstance(values, RaggedTensor):
      # pylint: disable=protected-access
      # The factory methods are responsible for converting values/partition to
      # a common dtype before reaching here; an assert (not a raise) suffices.
      assert row_partition.dtype == values._row_partition.dtype

    # Component tensors: flattened values plus the top-level row partition.
    self._values = values
    self._row_partition = row_partition
301
302  #=============================================================================
303  # Factory Methods
304  #=============================================================================
305
306  @classmethod
307  def _from_row_partition(cls, values, row_partition, validate=True):
308    """Creates a `RaggedTensor` with a row partition.
309
310    This is used as a way for RaggedTensors to share row partitions.
311
312    The outer dimension of values must be equal to `partition.nvals()`.
313
314    Args:
315      values: A potentially ragged tensor.
316      row_partition: a `RowPartition`: can be shared between tensors.
317      validate: If true, then use assertions to check that the arguments form a
318        valid `RaggedTensor`.
319
320    Returns:
321      A `RaggedTensor`.  `result.rank = values.rank + 1`.
322      `result.ragged_rank = values.ragged_rank + 1`.
323
324    Raises:
325      ValueError: If partition.nvals() != _nrows(values)
326    """
327    if not isinstance(row_partition, RowPartition):
328      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
329                      f"Received {row_partition}.")
330    if not isinstance(validate, bool):
331      raise TypeError(f"Argument `validate` must have type bool. "
332                      f"Received {validate}.")
333    values, row_partition = cls._convert_values_and_partition(
334        values, row_partition, "partition")
335    if row_partition.has_precomputed_value_rowids():
336      value_rowids_shape = row_partition.value_rowids().shape
337      values.shape[:1].assert_is_compatible_with(value_rowids_shape)
338    if validate:
339      msg = "Arguments to _from_row_partition do not form a valid RaggedTensor"
340      nvals = _nrows(values, row_partition.dtype)
341      checks = [
342          check_ops.assert_equal(
343              row_partition.nvals(out_type=row_partition.dtype),
344              nvals,
345              message=msg),
346      ]
347      if not isinstance(values, RaggedTensor):
348        checks.append(check_ops.assert_rank_at_least(values, 1))
349      row_partition = row_partition.with_dependencies(checks)
350    return cls(values=values, internal=True, row_partition=row_partition)
351
352  @classmethod
353  @dispatch.add_dispatch_support
354  def from_value_rowids(cls,
355                        values,
356                        value_rowids,
357                        nrows=None,
358                        name=None,
359                        validate=True):
360    """Creates a `RaggedTensor` with rows partitioned by `value_rowids`.
361
362    The returned `RaggedTensor` corresponds with the python list defined by:
363
364    ```python
365    result = [[values[i] for i in range(len(values)) if value_rowids[i] == row]
366              for row in range(nrows)]
367    ```
368
369    Args:
370      values: A potentially ragged tensor with shape `[nvals, ...]`.
371      value_rowids: A 1-D integer tensor with shape `[nvals]`, which corresponds
372        one-to-one with `values`, and specifies each value's row index.  Must be
373        nonnegative, and must be sorted in ascending order.
374      nrows: An integer scalar specifying the number of rows.  This should be
375        specified if the `RaggedTensor` may containing empty training rows. Must
376        be greater than `value_rowids[-1]` (or zero if `value_rowids` is empty).
377        Defaults to `value_rowids[-1]` (or zero if `value_rowids` is empty).
378      name: A name prefix for the RaggedTensor (optional).
379      validate: If true, then use assertions to check that the arguments form
380        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
381          since they must be checked for each tensor value.
382
383    Returns:
384      A `RaggedTensor`.  `result.rank = values.rank + 1`.
385      `result.ragged_rank = values.ragged_rank + 1`.
386
387    Raises:
388      ValueError: If `nrows` is incompatible with `value_rowids`.
389
390    #### Example:
391
392    >>> print(tf.RaggedTensor.from_value_rowids(
393    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
394    ...     value_rowids=[0, 0, 0, 0, 2, 2, 2, 3],
395    ...     nrows=5))
396    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
397
398    """
399    if not isinstance(validate, bool):
400      raise TypeError(f"Argument `validate` must have type bool. "
401                      f"Received {validate}.")
402
403    with ops.name_scope(name, "RaggedFromValueRowIds",
404                        [values, value_rowids, nrows]):
405      row_partition = RowPartition.from_value_rowids(
406          value_rowids=value_rowids,
407          nrows=nrows,
408          validate=validate,
409          preferred_dtype=_get_optional_partition_dtype(values))
410      return cls._from_row_partition(values, row_partition, validate=validate)
411
412  @classmethod
413  @dispatch.add_dispatch_support
414  def from_row_splits(cls, values, row_splits, name=None, validate=True):
415    """Creates a `RaggedTensor` with rows partitioned by `row_splits`.
416
417    The returned `RaggedTensor` corresponds with the python list defined by:
418
419    ```python
420    result = [values[row_splits[i]:row_splits[i + 1]]
421              for i in range(len(row_splits) - 1)]
422    ```
423
424    Args:
425      values: A potentially ragged tensor with shape `[nvals, ...]`.
426      row_splits: A 1-D integer tensor with shape `[nrows+1]`.  Must not be
427        empty, and must be sorted in ascending order.  `row_splits[0]` must be
428        zero and `row_splits[-1]` must be `nvals`.
429      name: A name prefix for the RaggedTensor (optional).
430      validate: If true, then use assertions to check that the arguments form
431        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
432          since they must be checked for each tensor value.
433
434    Returns:
435      A `RaggedTensor`.  `result.rank = values.rank + 1`.
436      `result.ragged_rank = values.ragged_rank + 1`.
437
438    Raises:
439      ValueError: If `row_splits` is an empty list.
440
441    #### Example:
442
443    >>> print(tf.RaggedTensor.from_row_splits(
444    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
445    ...     row_splits=[0, 4, 4, 7, 8, 8]))
446    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
447
448    """
449    if not isinstance(validate, bool):
450      raise TypeError(f"Argument `validate` must have type bool. "
451                      f"Received {validate}.")
452
453    with ops.name_scope(name, "RaggedFromRowSplits", [values, row_splits]):
454      row_partition = RowPartition.from_row_splits(
455          row_splits=row_splits,
456          validate=validate,
457          preferred_dtype=_get_optional_partition_dtype(values))
458      return cls._from_row_partition(values, row_partition, validate=validate)
459
460  @classmethod
461  @dispatch.add_dispatch_support
462  def from_row_lengths(cls, values, row_lengths, name=None, validate=True):
463    """Creates a `RaggedTensor` with rows partitioned by `row_lengths`.
464
465    The returned `RaggedTensor` corresponds with the python list defined by:
466
467    ```python
468    result = [[values.pop(0) for i in range(length)]
469              for length in row_lengths]
470    ```
471
472    Args:
473      values: A potentially ragged tensor with shape `[nvals, ...]`.
474      row_lengths: A 1-D integer tensor with shape `[nrows]`.  Must be
475        nonnegative.  `sum(row_lengths)` must be `nvals`.
476      name: A name prefix for the RaggedTensor (optional).
477      validate: If true, then use assertions to check that the arguments form
478        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
479          since they must be checked for each tensor value.
480
481    Returns:
482      A `RaggedTensor`.  `result.rank = values.rank + 1`.
483      `result.ragged_rank = values.ragged_rank + 1`.
484
485    #### Example:
486
487    >>> print(tf.RaggedTensor.from_row_lengths(
488    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
489    ...     row_lengths=[4, 0, 3, 1, 0]))
490    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
491
492    """
493    if not isinstance(validate, bool):
494      raise TypeError(f"Argument `validate` must have type bool. "
495                      f"Received {validate}.")
496
497    with ops.name_scope(name, "RaggedFromRowLengths", [values, row_lengths]):
498      row_partition = RowPartition.from_row_lengths(
499          row_lengths=row_lengths,
500          validate=validate,
501          preferred_dtype=_get_optional_partition_dtype(values))
502      return cls._from_row_partition(values, row_partition, validate=validate)
503
504  @classmethod
505  @dispatch.add_dispatch_support
506  def from_row_starts(cls, values, row_starts, name=None, validate=True):
507    """Creates a `RaggedTensor` with rows partitioned by `row_starts`.
508
509    Equivalent to: `from_row_splits(values, concat([row_starts, nvals]))`.
510
511    Args:
512      values: A potentially ragged tensor with shape `[nvals, ...]`.
513      row_starts: A 1-D integer tensor with shape `[nrows]`.  Must be
514        nonnegative and sorted in ascending order.  If `nrows>0`, then
515        `row_starts[0]` must be zero.
516      name: A name prefix for the RaggedTensor (optional).
517      validate: If true, then use assertions to check that the arguments form
518        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
519          since they must be checked for each tensor value.
520
521    Returns:
522      A `RaggedTensor`.  `result.rank = values.rank + 1`.
523      `result.ragged_rank = values.ragged_rank + 1`.
524
525    #### Example:
526
527    >>> print(tf.RaggedTensor.from_row_starts(
528    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
529    ...     row_starts=[0, 4, 4, 7, 8]))
530    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
531
532    """
533    if not isinstance(validate, bool):
534      raise TypeError(f"Argument `validate` must have type bool. "
535                      f"Received {validate}.")
536    with ops.name_scope(name, "RaggedFromRowStarts", [values, row_starts]):
537      values = _convert_to_ragged_tensor_values(values)
538      row_partition = RowPartition.from_row_starts(
539          row_starts=row_starts,
540          nvals=_nrows(values),
541          validate=validate,
542          preferred_dtype=_get_optional_partition_dtype(values))
543      return cls._from_row_partition(values, row_partition, validate=validate)
544
545  @classmethod
546  @dispatch.add_dispatch_support
547  def from_row_limits(cls, values, row_limits, name=None, validate=True):
548    """Creates a `RaggedTensor` with rows partitioned by `row_limits`.
549
550    Equivalent to: `from_row_splits(values, concat([0, row_limits]))`.
551
552    Args:
553      values: A potentially ragged tensor with shape `[nvals, ...]`.
554      row_limits: A 1-D integer tensor with shape `[nrows]`.  Must be sorted in
555        ascending order.  If `nrows>0`, then `row_limits[-1]` must be `nvals`.
556      name: A name prefix for the RaggedTensor (optional).
557      validate: If true, then use assertions to check that the arguments form
558        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
559          since they must be checked for each tensor value.
560
561    Returns:
562      A `RaggedTensor`.  `result.rank = values.rank + 1`.
563      `result.ragged_rank = values.ragged_rank + 1`.
564
565    #### Example:
566
567    >>> print(tf.RaggedTensor.from_row_limits(
568    ...     values=[3, 1, 4, 1, 5, 9, 2, 6],
569    ...     row_limits=[4, 4, 7, 8, 8]))
570    <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9, 2], [6], []]>
571
572    """
573    if not isinstance(validate, bool):
574      raise TypeError(f"Argument `validate` must have type bool. "
575                      f"Received {validate}.")
576    with ops.name_scope(name, "RaggedFromRowLimits", [values, row_limits]):
577      row_partition = RowPartition.from_row_limits(
578          row_limits=row_limits,
579          validate=validate,
580          preferred_dtype=_get_optional_partition_dtype(values))
581      return cls._from_row_partition(values, row_partition, validate=validate)
582
583  @classmethod
584  @dispatch.add_dispatch_support
585  def from_uniform_row_length(cls,
586                              values,
587                              uniform_row_length,
588                              nrows=None,
589                              validate=True,
590                              name=None):
591    """Creates a `RaggedTensor` with rows partitioned by `uniform_row_length`.
592
593    This method can be used to create `RaggedTensor`s with multiple uniform
594    outer dimensions.  For example, a `RaggedTensor` with shape `[2, 2, None]`
595    can be constructed with this method from a `RaggedTensor` values with shape
596    `[4, None]`:
597
598    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
599    >>> print(values.shape)
600    (4, None)
601    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
602    >>> print(rt1)
603    <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
604    >>> print(rt1.shape)
605    (2, 2, None)
606
607    Note that `rt1` only contains one ragged dimension (the innermost
608    dimension). In contrast, if `from_row_splits` is used to construct a similar
609    `RaggedTensor`, then that `RaggedTensor` will have two ragged dimensions:
610
611    >>> rt2 = tf.RaggedTensor.from_row_splits(values, [0, 2, 4])
612    >>> print(rt2.shape)
613    (2, None, None)
614
615    Args:
616      values: A potentially ragged tensor with shape `[nvals, ...]`.
617      uniform_row_length: A scalar integer tensor.  Must be nonnegative. The
618        size of the outer axis of `values` must be evenly divisible by
619        `uniform_row_length`.
620      nrows: The number of rows in the constructed RaggedTensor.  If not
621        specified, then it defaults to `nvals/uniform_row_length` (or `0` if
622        `uniform_row_length==0`).  `nrows` only needs to be specified if
623        `uniform_row_length` might be zero.  `uniform_row_length*nrows` must be
624        `nvals`.
625      validate: If true, then use assertions to check that the arguments form
626        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
627          since they must be checked for each tensor value.
628      name: A name prefix for the RaggedTensor (optional).
629
630    Returns:
631      A `RaggedTensor` that corresponds with the python list defined by:
632
633      ```python
634      result = [[values.pop(0) for i in range(uniform_row_length)]
635                for _ in range(nrows)]
636      ```
637
638      `result.rank = values.rank + 1`.
639      `result.ragged_rank = values.ragged_rank + 1`.
640    """
641    if not isinstance(validate, bool):
642      raise TypeError(f"Argument `validate` must have type bool. "
643                      f"Received {validate}.")
644    with ops.name_scope(name, "RaggedFromUniformRowLength",
645                        [values, uniform_row_length, nrows]):
646      values = _convert_to_ragged_tensor_values(values)
647      uniform_row_length = _convert_row_partition(
648          uniform_row_length, "UniformRowLength",
649          _get_optional_partition_dtype(values))
650      nvals = _nvals_uniform_row_length(values, uniform_row_length)
651      row_partition = RowPartition.from_uniform_row_length(
652          uniform_row_length=uniform_row_length,
653          nvals=nvals,
654          nrows=nrows,
655          validate=validate,
656          preferred_dtype=_get_optional_partition_dtype(values))
657      return cls._from_row_partition(values, row_partition, validate=validate)
658
659  @classmethod
660  @dispatch.add_dispatch_support
661  def from_nested_value_rowids(cls,
662                               flat_values,
663                               nested_value_rowids,
664                               nested_nrows=None,
665                               name=None,
666                               validate=True):
667    """Creates a `RaggedTensor` from a nested list of `value_rowids` tensors.
668
669    Equivalent to:
670
671    ```python
672    result = flat_values
673    for (rowids, nrows) in reversed(zip(nested_value_rowids, nested_nrows)):
674      result = from_value_rowids(result, rowids, nrows)
675    ```
676
677    Args:
678      flat_values: A potentially ragged tensor.
679      nested_value_rowids: A list of 1-D integer tensors.  The `i`th tensor is
680        used as the `value_rowids` for the `i`th ragged dimension.
681      nested_nrows: A list of integer scalars.  The `i`th scalar is used as the
682        `nrows` for the `i`th ragged dimension.
683      name: A name prefix for the RaggedTensor (optional).
684      validate: If true, then use assertions to check that the arguments form
685        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
686          since they must be checked for each tensor value.
687
688    Returns:
689      A `RaggedTensor` (or `flat_values` if `nested_value_rowids` is empty).
690
691    Raises:
692      ValueError: If `len(nested_values_rowids) != len(nested_nrows)`.
693    """
694    if not isinstance(validate, bool):
695      raise TypeError(f"Argument `validate` must have type bool. "
696                      f"Received {validate}.")
697    if isinstance(nested_value_rowids, ops.Tensor):
698      raise TypeError(f"Argument `nested_value_rowids` must be a list of "
699                      f"Tensors. Received {nested_value_rowids}.")
700    if nested_nrows is None:
701      nested_nrows = [None] * len(nested_value_rowids)
702    else:
703      if isinstance(nested_nrows, ops.Tensor):
704        raise TypeError(f"Argument `nested_nrows` must be a list of "
705                        f"Tensors. Received {nested_nrows}.")
706      if len(nested_nrows) != len(nested_value_rowids):
707        raise ValueError(
708            f"Argument `nested_nrows` must have the same length as "
709            f"argument `nested_value_rowids`. len(nested_nrows) = "
710            f"{len(nested_nrows)} vs. len(nested_values_rowids) = "
711            f"{len(nested_value_rowids)}.")
712
713    with ops.name_scope(name, "RaggedFromNestedValueRowIds", [flat_values] +
714                        list(nested_value_rowids) + list(nested_nrows)):
715      result = flat_values
716      for value_rowids, nrows in reversed(
717          list(zip(nested_value_rowids, nested_nrows))):
718        result = cls.from_value_rowids(
719            result, value_rowids, nrows, validate=validate)
720      return result
721
722  @classmethod
723  @dispatch.add_dispatch_support
724  def from_nested_row_splits(cls,
725                             flat_values,
726                             nested_row_splits,
727                             name=None,
728                             validate=True):
729    """Creates a `RaggedTensor` from a nested list of `row_splits` tensors.
730
731    Equivalent to:
732
733    ```python
734    result = flat_values
735    for row_splits in reversed(nested_row_splits):
736      result = from_row_splits(result, row_splits)
737    ```
738
739    Args:
740      flat_values: A potentially ragged tensor.
741      nested_row_splits: A list of 1-D integer tensors.  The `i`th tensor is
742        used as the `row_splits` for the `i`th ragged dimension.
743      name: A name prefix for the RaggedTensor (optional).
744      validate: If true, then use assertions to check that the arguments form
745        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
746          since they must be checked for each tensor value.
747
748    Returns:
749      A `RaggedTensor` (or `flat_values` if `nested_row_splits` is empty).
750    """
751    if not isinstance(validate, bool):
752      raise TypeError(f"Argument `validate` must have type bool. "
753                      f"Received {validate}.")
754    if isinstance(nested_row_splits, ops.Tensor):
755      raise TypeError(f"Argument `nested_row_splits` must be a list of "
756                      f"Tensors. Received {nested_row_splits}.")
757    with ops.name_scope(name, "RaggedFromNestedRowSplits",
758                        [flat_values] + list(nested_row_splits)):
759      result = flat_values
760      for splits in reversed(nested_row_splits):
761        result = cls.from_row_splits(result, splits, validate=validate)
762      return result
763
764  @classmethod
765  @dispatch.add_dispatch_support
766  def from_nested_row_lengths(cls,
767                              flat_values,
768                              nested_row_lengths,
769                              name=None,
770                              validate=True):
771    """Creates a `RaggedTensor` from a nested list of `row_lengths` tensors.
772
773    Equivalent to:
774
775    ```python
776    result = flat_values
777    for row_lengths in reversed(nested_row_lengths):
778      result = from_row_lengths(result, row_lengths)
779    ```
780
781    Args:
782      flat_values: A potentially ragged tensor.
783      nested_row_lengths: A list of 1-D integer tensors.  The `i`th tensor is
784        used as the `row_lengths` for the `i`th ragged dimension.
785      name: A name prefix for the RaggedTensor (optional).
786      validate: If true, then use assertions to check that the arguments form
787        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
788          since they must be checked for each tensor value.
789
790    Returns:
791      A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
792    """
793    if not isinstance(validate, bool):
794      raise TypeError(f"Argument `validate` must have type bool. "
795                      f"Received {validate}.")
796    if isinstance(nested_row_lengths, ops.Tensor):
797      raise TypeError(f"Argument `nested_row_lengths` must be a list of "
798                      f"Tensors. Received {nested_row_lengths}.")
799    with ops.name_scope(name, "RaggedFromNestedRowlengths",
800                        [flat_values] + list(nested_row_lengths)):
801      result = flat_values
802      for lengths in reversed(nested_row_lengths):
803        result = cls.from_row_lengths(result, lengths, validate=validate)
804      return result
805
806  @classmethod
807  def _from_nested_row_partitions(cls,
808                                  flat_values,
809                                  nested_row_partitions,
810                                  name=None,
811                                  validate=True):
812    """Creates a `RaggedTensor` from a nested list of row partitions.
813
814    Equivalent to:
815
816    ```python
817    result = flat_values
818    for row_partition in reversed(nested_row_partitions):
819      result = _from_row_partition(result, row_partition)
820    ```
821
822    Args:
823      flat_values: A potentially ragged tensor.
824      nested_row_partitions: A list of row partitions.  The `i`th element is
825        used as the row partition for the `i`th ragged dimension.
826      name: A name prefix for the RaggedTensor (optional).
827      validate: If true, then use assertions to check that the arguments form
828        a valid `RaggedTensor`.  Note: these assertions incur a runtime cost,
829          since they must be checked for each tensor value.
830
831    Returns:
832      A `RaggedTensor` (or `flat_values` if `nested_row_lengths` is empty).
833    """
834    if not isinstance(validate, bool):
835      raise TypeError(f"Argument `validate` must have type bool. "
836                      f"Received {validate}.")
837    if isinstance(nested_row_partitions, RowPartition):
838      raise TypeError(f"Argument `nested_row_partitions` must be a list of "
839                      f"RowPartitions. Received {nested_row_partitions}.")
840    if isinstance(nested_row_partitions, ops.Tensor):
841      raise TypeError(f"Argument `nested_row_partitions` must be a list of "
842                      f"RowPartitions. Received {nested_row_partitions}.")
843    with ops.name_scope(name, "RaggedFromNestedRowPartitions",
844                        [flat_values] + list(nested_row_partitions)):
845      result = flat_values
846      for partition in reversed(nested_row_partitions):
847        result = cls._from_row_partition(result, partition, validate=validate)
848      return result
849
850  @classmethod
851  def _convert_values_and_partition(cls, values, row_partition, name):
852    """Converts `values` and `partition` to Tensors.
853
854    If `values` is a `RaggedTensor`, then converts `values` and `partition`
855    to have compatible row-partitioning dtypes.  In particular, if any of the
856    row partitioning tensors are `int64`, then all of the other row
857    partitioning tensors wil be cast to `int64` (if auto_cast_partition_dtype()
858    is true) or an error will be raised (if auto_cast_partition_dtype() is
859    false).
860
861    Args:
862      values: The `values` for the `RaggedTensor` being constructed.
863      row_partition: A RowPartition object for the `RaggedTensor` being
864        constructed.
865      name: The name of the RowPartition object.
866
867    Returns:
868      A tuple (values, partition).
869    """
870    if not isinstance(row_partition, RowPartition):
871      raise TypeError(f"Argument `row_partition` must be a RowPartition. "
872                      f"Received {row_partition}.")
873    if isinstance(values, RaggedTensor):
874      # pylint: disable=protected-access
875      if values._row_partition.dtype != row_partition.dtype:
876        if not ragged_config.auto_cast_partition_dtype():
877          # pylint: disable=protected-access
878          # TODO(edloper): get rid of the `name` parameter.
879          raise ValueError(
880              f"Argument `row_partition` of RaggedTensor with name: {name} "
881              f"must have same dtype as Argument `values`. "
882              f"({row_partition.dtype} vs. {values._row_partition.dtype}).")
883        values = values.with_row_splits_dtype(row_partition.dtype)
884    else:
885      values = _convert_to_ragged_tensor_values(values)
886
887    return (values, row_partition)
888
889  #=============================================================================
890  # Accessors
891  #=============================================================================
892
893  @property
894  def dtype(self):
895    """The `DType` of values in this tensor."""
896    return self._values.dtype
897
898  @property
899  def shape(self):
900    """The statically known shape of this ragged tensor.
901
902    Returns:
903      A `TensorShape` containing the statically known shape of this ragged
904      tensor.  Ragged dimensions have a size of `None`.
905
906    Examples:
907
908    >>> tf.ragged.constant([[0], [1, 2]]).shape
909    TensorShape([2, None])
910
911    >>> tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).shape
912    TensorShape([2, None, 2])
913
914    """
915    nrows = self._row_partition.static_nrows
916    ncols = self._row_partition.static_uniform_row_length
917    value_shape = self._values.shape[1:]
918    return tensor_shape.TensorShape([nrows, ncols]).concatenate(value_shape)
919
  def get_shape(self):
    """The statically known shape of this ragged tensor.

    Returns:
      A `TensorShape` containing the statically known shape of this ragged
      tensor.  Ragged dimensions have a size of `None`.

    Alias for `shape` property.

    Examples:

    >>> tf.ragged.constant([[0], [1, 2]]).get_shape()
    TensorShape([2, None])

    >>> tf.ragged.constant(
    ...    [[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1).get_shape()
    TensorShape([2, None, 2])

    """
    # Method form of the `shape` property, mirroring `tf.Tensor.get_shape()`.
    return self.shape
940
941  @property
942  def ragged_rank(self):
943    """The number of times the RaggedTensor's flat_values is partitioned.
944
945    Examples:
946
947    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
948    >>> values.ragged_rank
949    1
950
951    >>> rt = tf.RaggedTensor.from_uniform_row_length(values, 2)
952    >>> rt.ragged_rank
953    2
954
955    Returns:
956      A Python `int` indicating the number of times the underlying `flat_values`
957      Tensor has been partitioned to add a new dimension.
958      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
959    """
960    values_is_ragged = isinstance(self._values, RaggedTensor)
961    return self._values.ragged_rank + 1 if values_is_ragged else 1
962
  @property
  def values(self):
    """The concatenated rows for this ragged tensor.

    `rt.values` is a potentially ragged tensor formed by flattening the two
    outermost dimensions of `rt` into a single dimension.

    `rt.values.shape = [nvals] + rt.shape[2:]` (where `nvals` is the
    number of items in the outer two dimensions of `rt`).

    `rt.ragged_rank = self.ragged_rank - 1`

    Returns:
      A potentially ragged tensor.

    #### Example:

    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    >>> print(rt.values)
    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)

    """
    # Direct accessor for the stored values component.
    return self._values
986
987  @property
988  def _nested_row_partitions(self):
989    """Returns the row partitions for this `RaggedTensor`."""
990    partitions = [self._row_partition]
991    rt_values = self.values
992    while isinstance(rt_values, RaggedTensor):
993      # pylint: disable=protected-access
994      partitions.append(rt_values._row_partition)
995      rt_values = rt_values.values
996    return tuple(partitions)
997
  @property
  def row_splits(self):
    """The row-split indices for this ragged tensor's `values`.

    `rt.row_splits` specifies where the values for each row begin and end in
    `rt.values`.  In particular, the values for row `rt[i]` are stored in
    the slice `rt.values[rt.row_splits[i]:rt.row_splits[i+1]]`.

    Returns:
      A 1-D integer `Tensor` with shape `[self.nrows+1]`.
      The returned tensor is non-empty, and is sorted in ascending order.
      `self.row_splits[0]` is zero, and `self.row_splits[-1]` is equal to
      `self.values.shape[0]`.

    #### Example:

    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    >>> print(rt.row_splits)  # indices of row splits in rt.values
    tf.Tensor([0 4 4 7 8 8], shape=(6,), dtype=int64)

    """
    # Delegates to the underlying RowPartition.
    return self._row_partition.row_splits()
1020
  @property
  def uniform_row_length(self):
    """The length of each row in this ragged tensor, or None if rows are ragged.

    >>> rt1 = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
    >>> print(rt1.uniform_row_length)  # rows are ragged.
    None

    >>> rt2 = tf.RaggedTensor.from_uniform_row_length(
    ...     values=rt1, uniform_row_length=2)
    >>> print(rt2)
    <tf.RaggedTensor [[[1, 2, 3], [4]], [[5, 6], [7, 8, 9, 10]]]>
    >>> print(rt2.uniform_row_length)  # rows are not ragged (all have size 2).
    tf.Tensor(2, shape=(), dtype=int64)

    A RaggedTensor's rows are only considered to be uniform (i.e. non-ragged)
    if it can be determined statically (at graph construction time) that the
    rows all have the same length.

    Returns:
      A scalar integer `Tensor`, specifying the length of every row in this
      ragged tensor (for ragged tensors whose rows are uniform); or `None`
      (for ragged tensors whose rows are ragged).
    """
    # Delegates to the underlying RowPartition.
    return self._row_partition.uniform_row_length()
1046
1047  @property
1048  def flat_values(self):
1049    """The innermost `values` tensor for this ragged tensor.
1050
1051    Concretely, if `rt.values` is a `Tensor`, then `rt.flat_values` is
1052    `rt.values`; otherwise, `rt.flat_values` is `rt.values.flat_values`.
1053
1054    Conceptually, `flat_values` is the tensor formed by flattening the
1055    outermost dimension and all of the ragged dimensions into a single
1056    dimension.
1057
1058    `rt.flat_values.shape = [nvals] + rt.shape[rt.ragged_rank + 1:]`
1059    (where `nvals` is the number of items in the flattened dimensions).
1060
1061    Returns:
1062      A `Tensor`.
1063
1064    #### Example:
1065
1066    >>> rt = tf.ragged.constant([[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
1067    >>> print(rt.flat_values)
1068    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1069
1070    """
1071    rt_values = self.values
1072    while isinstance(rt_values, RaggedTensor):
1073      rt_values = rt_values.values
1074    return rt_values
1075
1076  @property
1077  def nested_row_splits(self):
1078    """A tuple containing the row_splits for all ragged dimensions.
1079
1080    `rt.nested_row_splits` is a tuple containing the `row_splits` tensors for
1081    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
1082    particular, `rt.nested_row_splits = (rt.row_splits,) + value_splits` where:
1083
1084        * `value_splits = ()` if `rt.values` is a `Tensor`.
1085        * `value_splits = rt.values.nested_row_splits` otherwise.
1086
1087    Returns:
1088      A `tuple` of 1-D integer `Tensor`s.
1089
1090    #### Example:
1091
1092    >>> rt = tf.ragged.constant(
1093    ...     [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1094    >>> for i, splits in enumerate(rt.nested_row_splits):
1095    ...   print('Splits for dimension %d: %s' % (i+1, splits.numpy()))
1096    Splits for dimension 1: [0 3]
1097    Splits for dimension 2: [0 3 3 5]
1098    Splits for dimension 3: [0 4 4 7 8 8]
1099
1100    """
1101    rt_nested_splits = [self.row_splits]
1102    rt_values = self.values
1103    while isinstance(rt_values, RaggedTensor):
1104      rt_nested_splits.append(rt_values.row_splits)
1105      rt_values = rt_values.values
1106    return tuple(rt_nested_splits)
1107
  def value_rowids(self, name=None):
    """Returns the row indices for the `values` in this ragged tensor.

    `rt.value_rowids()` corresponds one-to-one with the outermost dimension of
    `rt.values`, and specifies the row containing each value.  In particular,
    the row `rt[row]` consists of the values `rt.values[j]` where
    `rt.value_rowids()[j] == row`.

    Args:
      name: A name prefix for the returned tensor (optional).

    Returns:
      A 1-D integer `Tensor` with shape `self.values.shape[:1]`.
      The returned tensor is nonnegative, and is sorted in ascending order.

    #### Example:

    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    >>> print(rt.values)
    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
    >>> print(rt.value_rowids())  # corresponds 1:1 with rt.values
    tf.Tensor([0 0 0 0 2 2 2 3], shape=(8,), dtype=int64)

    """
    # Delegates to the underlying RowPartition, inside a name scope.
    with ops.name_scope(name, "RaggedValueRowIds", [self]):
      return self._row_partition.value_rowids()
1134
1135  def nested_value_rowids(self, name=None):
1136    """Returns a tuple containing the value_rowids for all ragged dimensions.
1137
1138    `rt.nested_value_rowids` is a tuple containing the `value_rowids` tensors
1139    for
1140    all ragged dimensions in `rt`, ordered from outermost to innermost.  In
1141    particular, `rt.nested_value_rowids = (rt.value_rowids(),) + value_ids`
1142    where:
1143
1144    * `value_ids = ()` if `rt.values` is a `Tensor`.
1145    * `value_ids = rt.values.nested_value_rowids` otherwise.
1146
1147    Args:
1148      name: A name prefix for the returned tensors (optional).
1149
1150    Returns:
1151      A `tuple` of 1-D integer `Tensor`s.
1152
1153    #### Example:
1154
1155    >>> rt = tf.ragged.constant(
1156    ...     [[[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]]])
1157    >>> for i, ids in enumerate(rt.nested_value_rowids()):
1158    ...   print('row ids for dimension %d: %s' % (i+1, ids.numpy()))
1159    row ids for dimension 1: [0 0 0]
1160    row ids for dimension 2: [0 0 0 2 2]
1161    row ids for dimension 3: [0 0 0 0 2 2 2 3]
1162
1163    """
1164    with ops.name_scope(name, "RaggedNestedValueRowIds", [self]):
1165      rt_nested_ids = [self.value_rowids()]
1166      rt_values = self.values
1167      while isinstance(rt_values, RaggedTensor):
1168        rt_nested_ids.append(rt_values.value_rowids())
1169        rt_values = rt_values.values
1170      return tuple(rt_nested_ids)
1171
  def nrows(self, out_type=None, name=None):
    """Returns the number of rows in this ragged tensor.

    I.e., the size of the outermost dimension of the tensor.

    Args:
      out_type: `dtype` for the returned tensor.  Defaults to
        `self.row_splits.dtype`.
      name: A name prefix for the returned tensor (optional).

    Returns:
      A scalar `Tensor` with dtype `out_type`.

    #### Example:

    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    >>> print(rt.nrows())  # rt has 5 rows.
    tf.Tensor(5, shape=(), dtype=int64)

    """
    # Delegates to the underlying RowPartition, inside a name scope.
    with ops.name_scope(name, "RaggedNRows", [self]):
      return self._row_partition.nrows(out_type=out_type)
1194
  def row_starts(self, name=None):
    """Returns the start indices for rows in this ragged tensor.

    These indices specify where the values for each row begin in
    `self.values`.  `rt.row_starts()` is equal to `rt.row_splits[:-1]`.

    Args:
      name: A name prefix for the returned tensor (optional).

    Returns:
      A 1-D integer Tensor with shape `[nrows]`.
      The returned tensor is nonnegative, and is sorted in ascending order.

    #### Example:

    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
    >>> print(rt.values)
    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
    >>> print(rt.row_starts())  # indices of row starts in rt.values
    tf.Tensor([0 4 4 7 8], shape=(5,), dtype=int64)

    """
    # Delegates to the underlying RowPartition, inside a name scope.
    with ops.name_scope(name, "RaggedRowStarts", [self]):
      return self._row_partition.row_starts()
1219
1220  def row_limits(self, name=None):
1221    """Returns the limit indices for rows in this ragged tensor.
1222
1223    These indices specify where the values for each row end in
1224    `self.values`.  `rt.row_limits(self)` is equal to `rt.row_splits[:-1]`.
1225
1226    Args:
1227      name: A name prefix for the returned tensor (optional).
1228
1229    Returns:
1230      A 1-D integer Tensor with shape `[nrows]`.
1231      The returned tensor is nonnegative, and is sorted in ascending order.
1232
1233    #### Example:
1234
1235    >>> rt = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
1236    >>> print(rt.values)
1237    tf.Tensor([3 1 4 1 5 9 2 6], shape=(8,), dtype=int32)
1238    >>> print(rt.row_limits())  # indices of row limits in rt.values
1239    tf.Tensor([4 4 7 8 8], shape=(5,), dtype=int64)
1240
1241    """
1242    with ops.name_scope(name, "RaggedRowLimits", [self]):
1243      return self._row_partition.row_limits()
1244
  def row_lengths(self, axis=1, name=None):
    """Returns the lengths of the rows in this ragged tensor.

    `rt.row_lengths()[i]` indicates the number of values in the
    `i`th row of `rt`.

    Args:
      axis: An integer constant indicating the axis whose row lengths should be
        returned.
      name: A name prefix for the returned tensor (optional).

    Returns:
      A potentially ragged integer Tensor with shape `self.shape[:axis]`.

    Raises:
      ValueError: If `axis` is out of bounds.

    #### Example:

    >>> rt = tf.ragged.constant(
    ...     [[[3, 1, 4], [1]], [], [[5, 9], [2]], [[6]], []])
    >>> print(rt.row_lengths())  # lengths of rows in rt
    tf.Tensor([2 0 2 1 0], shape=(5,), dtype=int64)
    >>> print(rt.row_lengths(axis=2))  # lengths of axis=2 rows.
    <tf.RaggedTensor [[3, 1], [], [2, 1], [1], []]>

    """
    # Fast paths for the common literal axis values; these bypass the name
    # scope below and read directly from the RowPartition.
    if axis == 0:
      return self._row_partition.nrows()

    if axis == 1:
      return self._row_partition.row_lengths()

    with ops.name_scope(name, "RaggedRowLengths", [self]):
      # Normalize negative axis values; raises if axis is out of bounds.
      axis = array_ops.get_positive_axis(
          axis, self.shape.rank, ndims_name="rank(self)")
      # The axis==0/1 cases are re-checked here because a negative axis can
      # normalize to 0 or 1 (the fast paths above only match literal 0/1).
      if axis == 0:
        return self.nrows()
      elif axis == 1:
        splits = self.row_splits
        return splits[1:] - splits[:-1]
      elif isinstance(self.values, RaggedTensor):
        # Recurse into the nested ragged values, shifting the axis down by one.
        return self.with_values(self.values.row_lengths(axis - 1))
      else:
        # The requested axis lies in the uniform (dense) inner dimensions, so
        # every row has the same length: broadcast that dimension's size.
        shape = array_ops.shape(self.values, out_type=self._row_partition.dtype)
        return self.with_values(
            array_ops.ones(shape[:axis - 1], self._row_partition.dtype) *
            shape[axis - 1])
1293
1294  def nested_row_lengths(self, name=None):
1295    """Returns a tuple containing the row_lengths for all ragged dimensions.
1296
1297    `rt.nested_row_lengths()` is a tuple containing the `row_lengths` tensors
1298    for all ragged dimensions in `rt`, ordered from outermost to innermost.
1299
1300    Args:
1301      name: A name prefix for the returned tensors (optional).
1302
1303    Returns:
1304      A `tuple` of 1-D integer `Tensors`.  The length of the tuple is equal to
1305      `self.ragged_rank`.
1306    """
1307    with ops.name_scope(name, "RaggedNestedRowLengths", [self]):
1308      rt_nested_row_lengths = []
1309      rt = self
1310      while isinstance(rt, RaggedTensor):
1311        rt_nested_row_lengths.append(rt.row_lengths())
1312        rt = rt.values
1313      return tuple(rt_nested_row_lengths)
1314
  def bounding_shape(self, axis=None, name=None, out_type=None):
    """Returns the tight bounding box shape for this `RaggedTensor`.

    Args:
      axis: An integer scalar or vector indicating which axes to return the
        bounding box for.  If not specified, then the full bounding box is
        returned.
      name: A name prefix for the returned tensor (optional).
      out_type: `dtype` for the returned tensor.  Defaults to
        `self.row_splits.dtype`.

    Returns:
      An integer `Tensor` (`dtype=self.row_splits.dtype`).  If `axis` is not
      specified, then `output` is a vector with
      `output.shape=[self.shape.ndims]`.  If `axis` is a scalar, then the
      `output` is a scalar.  If `axis` is a vector, then `output` is a vector,
      where `output[i]` is the bounding size for dimension `axis[i]`.

    #### Example:

    >>> rt = tf.ragged.constant([[1, 2, 3, 4], [5], [], [6, 7, 8, 9], [10]])
    >>> rt.bounding_shape().numpy()
    array([5, 4])

    """
    # Resolve the output dtype up-front; it defaults to the partition dtype.
    if out_type is None:
      out_type = self._row_partition.dtype
    else:
      out_type = dtypes.as_dtype(out_type)
    with ops.name_scope(name, "RaggedBoundingBox", [self, axis]):
      nested_splits = self.nested_row_splits
      rt_flat_values = self.flat_values

      # Optimized special cases for when axis=0 or axis=1:
      if isinstance(axis, int):
        if axis == 0:
          # Number of rows: length of the outermost row_splits minus one.
          return array_ops.shape(nested_splits[0], out_type=out_type)[0] - 1
        elif axis == 1:
          # Longest row; max with 0 so an empty tensor yields 0, not a
          # reduce_max over an empty set.
          result = math_ops.maximum(math_ops.reduce_max(self.row_lengths()), 0)
          if out_type != self._row_partition.dtype:
            result = math_ops.cast(result, out_type)
          return result

      # General case: compute the bounding size of every dimension.
      splits_shape = array_ops.shape(self.row_splits, out_type=out_type)
      flat_values_shape = array_ops.shape(rt_flat_values, out_type=out_type)

      # One entry per ragged dimension plus the outermost (nrows) entry; each
      # ragged entry is the maximum row length at that nesting level.
      ragged_dimensions = [splits_shape[0] - 1] + [
          math_ops.maximum(math_ops.reduce_max(splits[1:] - splits[:-1]), 0)
          for splits in nested_splits
      ]
      # The dense inner dimensions come straight from flat_values' shape.
      inner_dimensions = flat_values_shape[1:]

      # The ragged entries were computed in the partition dtype; cast them if
      # the caller requested a different out_type.
      if out_type != self._row_partition.dtype:
        ragged_dimensions = [
            math_ops.cast(d, out_type) for d in ragged_dimensions
        ]
      bbox = array_ops.concat(
          [array_ops.stack(ragged_dimensions), inner_dimensions], axis=0)
      # With axis=None return the whole bounding shape; otherwise gather the
      # requested dimension(s).
      return bbox if axis is None else array_ops.gather(bbox, axis)
1374
1375  #=============================================================================
1376  # Transformation
1377  #=============================================================================
1378
1379  def with_values(self, new_values):
1380    """Returns a copy of `self` with `values` replaced by `new_value`.
1381
1382    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1383    `self.cached_value_rowids` if they have values.
1384
1385    Args:
1386      new_values: Potentially ragged tensor to use as the `values` for the
1387        returned `RaggedTensor`.  Must have `rank > 0`, and must have the same
1388        number of rows as `self.values`.
1389
1390    Returns:
1391      A `RaggedTensor`.  `result.rank = 1 + new_values.rank`.
1392      `result.ragged_rank = 1 + new_values.ragged_rank`
1393    """
1394    new_values = _convert_to_ragged_tensor_values(new_values)
1395    new_values.shape.with_rank_at_least(1)
1396    self.values.shape[:1].assert_is_compatible_with(new_values.shape[:1])
1397    if (isinstance(new_values, RaggedTensor) and
1398        self._row_partition.dtype != new_values.row_splits.dtype):
1399      if not ragged_config.auto_cast_partition_dtype():
1400        raise ValueError("self and new_values have mismatched row_splits "
1401                         "dtypes; use RaggedTensor.with_row_splits_dtype() to "
1402                         "convert them to compatible dtypes.")
1403      new_values = new_values.with_row_splits_dtype(dtypes.int64)
1404      return self.with_row_splits_dtype(dtypes.int64).with_values(new_values)
1405    return RaggedTensor(
1406        values=new_values, row_partition=self._row_partition, internal=True)
1407
1408  def with_flat_values(self, new_values):
1409    """Returns a copy of `self` with `flat_values` replaced by `new_value`.
1410
1411    Preserves cached row-partitioning tensors such as `self.cached_nrows` and
1412    `self.cached_value_rowids` if they have values.
1413
1414    Args:
1415      new_values: Potentially ragged tensor that should replace
1416        `self.flat_values`.  Must have `rank > 0`, and must have the same number
1417        of rows as `self.flat_values`.
1418
1419    Returns:
1420      A `RaggedTensor`.
1421      `result.rank = self.ragged_rank + new_values.rank`.
1422      `result.ragged_rank = self.ragged_rank + new_values.ragged_rank`.
1423    """
1424    if isinstance(self._values, RaggedTensor):
1425      return self.with_values(self.values.with_flat_values(new_values))
1426    else:
1427      new_values = _convert_to_ragged_tensor_values(new_values)
1428    return self.with_values(new_values)
1429
1430  def with_row_splits_dtype(self, dtype):
1431    """Returns a copy of this RaggedTensor with the given `row_splits` dtype.
1432
1433    For RaggedTensors with multiple ragged dimensions, the `row_splits` for all
1434    nested `RaggedTensor` objects are cast to the given dtype.
1435
1436    Args:
1437      dtype: The dtype for `row_splits`.  One of `tf.int32` or `tf.int64`.
1438
1439    Returns:
1440      A copy of this RaggedTensor, with the `row_splits` cast to the given
1441      type.
1442    """
1443    dtype = dtypes.as_dtype(dtype)
1444    if dtype not in (dtypes.int32, dtypes.int64):
1445      raise ValueError(f"Argument `row_splits` dtype must be int32 or int64. "
1446                       f"Received {dtype}.")
1447    if self._row_partition.dtype == dtype:
1448      return self
1449    current_values = self._values
1450    if isinstance(current_values, RaggedTensor):
1451      return RaggedTensor(
1452          values=current_values.with_row_splits_dtype(dtype),
1453          row_partition=self._row_partition.with_row_splits_dtype(dtype),
1454          internal=True)
1455    else:
1456      return RaggedTensor(
1457          values=current_values,
1458          row_partition=self._row_partition.with_row_splits_dtype(dtype),
1459          internal=True)
1460
1461  def merge_dims(self, outer_axis, inner_axis):
1462    """Merges outer_axis...inner_axis into a single dimension.
1463
1464    Returns a copy of this RaggedTensor with the specified range of dimensions
1465    flattened into a single dimension, with elements in row-major order.
1466
1467    #### Examples:
1468
1469    >>> rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]]])
1470    >>> print(rt.merge_dims(0, 1))
1471    <tf.RaggedTensor [[1, 2], [3], [4, 5, 6]]>
1472    >>> print(rt.merge_dims(1, 2))
1473    <tf.RaggedTensor [[1, 2, 3], [4, 5, 6]]>
1474    >>> print(rt.merge_dims(0, 2))
1475    tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32)
1476
1477    To mimic the behavior of `np.flatten` (which flattens all dimensions), use
1478    `rt.merge_dims(0, -1).  To mimic the behavior of `tf.layers.Flatten` (which
1479    flattens all dimensions except the outermost batch dimension), use
1480    `rt.merge_dims(1, -1)`.
1481
1482    Args:
1483      outer_axis: `int`: The first dimension in the range of dimensions to
1484        merge. May be negative if `self.shape.rank` is statically known.
1485      inner_axis: `int`: The last dimension in the range of dimensions to merge.
1486        May be negative if `self.shape.rank` is statically known.
1487
1488    Returns:
1489      A copy of this tensor, with the specified dimensions merged into a
1490      single dimension.  The shape of the returned tensor will be
1491      `self.shape[:outer_axis] + [N] + self.shape[inner_axis + 1:]`, where `N`
1492      is the total number of slices in the merged dimensions.
1493    """
1494    outer_axis = array_ops.get_positive_axis(
1495        outer_axis,
1496        self.shape.rank,
1497        axis_name="outer_axis",
1498        ndims_name="rank(self)")
1499    inner_axis = array_ops.get_positive_axis(
1500        inner_axis,
1501        self.shape.rank,
1502        axis_name="inner_axis",
1503        ndims_name="rank(self)")
1504    if not outer_axis <= inner_axis:
1505      raise ValueError(f"Expected outer_axis ({outer_axis}) to be less than or "
1506                       f"equal to inner_axis ({inner_axis}).")
1507    return merge_dims(self, outer_axis, inner_axis)
1508
  def _set_shape(self, shape):
    """Updates the static shape of `self` to be `shape`.

    * If a dimension of `shape` has known rank, and is encoded via
      partitioning, then this will update the corresponding partition to
      define `_uniform_row_length` and `nrows`.
    * If a dimension of `shape` has a known rank, and is encoded as one
      of the `flat_values` dimensions, then `flat_values.set_shape()` will
      be used to update its shape.

    Warning: Using this method to assert an incorrect shape for a RaggedTensor
    (i.e., one that's not consistent with its actual shape) can cause
    segmentation faults and very difficult-to-diagnose behavior.  Only use this
    method if you are certain that the shape is correct.

    Args:
      shape: `tf.TensorShape` specifying the shape for this `RaggedTensor`.
    """
    # TODO(edloper): Refactor this to not directly access private members
    # of RowPartition.
    # pylint: disable=protected-access

    shape = tensor_shape.as_shape(shape)
    if shape.rank is None:
      return  # Nothing to do.

    shape = shape.as_list()

    # Outermost dimension: a ragged tensor with `nrows` rows has `nrows + 1`
    # row splits, so constraining the splits' static length records `nrows`.
    if shape[0] is not None:
      self._row_partition._row_splits.set_shape(shape[0] + 1)

    # Partitioned dimensions: for each inner partitioned axis whose size is
    # statically known, record that size as a uniform row length on the
    # corresponding RowPartition (after checking it doesn't contradict an
    # already-known uniform row length).
    dtype = self._row_partition.dtype
    for i, partition in enumerate(self._nested_row_partitions):
      size = shape[i + 1]
      if size is not None:
        if partition._uniform_row_length is not None:
          old_row_length = tensor_util.constant_value(
              partition._uniform_row_length)
          if old_row_length is not None:
            if size == old_row_length:
              continue  # already have shape info for this axis.
            else:
              raise ValueError(f"Inconsistent size for axis {i + 1}: "
                               f"{old_row_length} vs. {size}.")
        partition._uniform_row_length = ops.convert_to_tensor(size, dtype)
        if partition._nrows is None:
          partition._nrows = array_ops.size(partition._row_splits) - 1

    # Inner (dense) dimensions: delegate to `flat_values.set_shape`.  The
    # outermost flat_values dimension (total value count) stays unconstrained.
    flat_shape = tensor_shape.as_shape([None] + shape[self.ragged_rank + 1:])
    self.flat_values.set_shape(flat_shape)
1562
1563  #=============================================================================
1564  # Tensor Type Conversions
1565  #=============================================================================
1566
  @classmethod
  @dispatch.add_dispatch_support
  def from_tensor(cls,
                  tensor,
                  lengths=None,
                  padding=None,
                  ragged_rank=1,
                  name=None,
                  row_splits_dtype=dtypes.int64):
    """Converts a `tf.Tensor` into a `RaggedTensor`.

    The set of absent/default values may be specified using a vector of lengths
    or a padding value (but not both).  If `lengths` is specified, then the
    output tensor will satisfy `output[row] = tensor[row][:lengths[row]]`. If
    `lengths` is a list of lists or tuple of lists, those lists will be used
    as nested row lengths. If `padding` is specified, then any row *suffix*
    consisting entirely of `padding` will be excluded from the returned
    `RaggedTensor`.  If neither `lengths` nor `padding` is specified, then the
    returned `RaggedTensor` will have no absent/default values.

    Examples:

    >>> dt = tf.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
    >>> tf.RaggedTensor.from_tensor(dt)
    <tf.RaggedTensor [[5, 7, 0], [0, 3, 0], [6, 0, 0]]>
    >>> tf.RaggedTensor.from_tensor(dt, lengths=[1, 0, 3])
    <tf.RaggedTensor [[5], [], [6, 0, 0]]>

    >>> tf.RaggedTensor.from_tensor(dt, padding=0)
    <tf.RaggedTensor [[5, 7], [0, 3], [6]]>

    >>> dt = tf.constant([[[5, 0], [7, 0], [0, 0]],
    ...                   [[0, 0], [3, 0], [0, 0]],
    ...                   [[6, 0], [0, 0], [0, 0]]])
    >>> tf.RaggedTensor.from_tensor(dt, lengths=([2, 0, 3], [1, 1, 2, 0, 1]))
    <tf.RaggedTensor [[[5], [7]], [], [[6, 0], [], [0]]]>

    Args:
      tensor: The `Tensor` to convert.  Must have rank `ragged_rank + 1` or
        higher.
      lengths: An optional set of row lengths, specified using a 1-D integer
        `Tensor` whose length is equal to `tensor.shape[0]` (the number of rows
        in `tensor`).  If specified, then `output[row]` will contain
        `tensor[row][:lengths[row]]`.  Negative lengths are treated as zero. You
          may optionally pass a list or tuple of lengths to this argument, which
          will be used as nested row lengths to construct a ragged tensor with
          multiple ragged dimensions.
      padding: An optional padding value.  If specified, then any row suffix
        consisting entirely of `padding` will be excluded from the returned
        RaggedTensor.  `padding` is a `Tensor` with the same dtype as `tensor`
        and with `shape=tensor.shape[ragged_rank + 1:]`.
      ragged_rank: Integer specifying the ragged rank for the returned
        `RaggedTensor`.  Must be greater than zero.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the specified `ragged_rank`.  The shape of the
      returned ragged tensor is compatible with the shape of `tensor`.
    Raises:
      ValueError: If both `lengths` and `padding` are specified.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if lengths is not None and padding is not None:
      raise ValueError("Specify argument `lengths` or `padding`, but not both.")
    if not isinstance(ragged_rank, int):
      raise TypeError(f"Argument `ragged_rank` must be an int. "
                      f"Received {ragged_rank}.")
    if ragged_rank <= 0:
      raise ValueError(f"Argument `ragged_rank` must be greater than 0. "
                       f"Received {ragged_rank}.")

    with ops.name_scope(name, "RaggedFromTensor", [tensor, lengths, padding]):
      tensor = ops.convert_to_tensor(tensor, name="tensor")
      tensor.shape.with_rank_at_least(ragged_rank + 1)
      # Dynamic shape of the input; used wherever static sizes are unknown.
      input_shape = array_ops.shape(tensor, out_type=row_splits_dtype)
      ncols = input_shape[1]

      # Handle nested row lengths (a list/tuple of per-dimension length lists).
      if (lengths is not None and isinstance(lengths, (list, tuple)) and
          len(lengths) and not isinstance(lengths[0], (int, float))):
        if ragged_rank not in (1, len(lengths)):
          # Note: we accept `ragged_rank=1` here because it's the default value;
          # i.e., if the user passes in a tuple of lengths, but doesn't specify
          # ragged_rank, then we should use that tuple to determine ragged_rank.
          # We only want to complain if they pass in an explicit ragged_rank
          # that doesn't match len(lengths).
          raise ValueError(f"If Argument `lengths` is a tuple of row_lengths, "
                           f"argument `ragged_rank` must be "
                           f"len(lengths): {len(lengths)}. Received "
                           f"ragged_rank: {ragged_rank}.")
        # Rather than reconstructing the tensor mask directly, we can
        # recreate it as a boolean RaggedTensor, then densify that and use
        # that as the mask to clear out the unused data in the passed tensor.
        tensor.shape.with_rank_at_least(len(lengths) + 1)
        num_tokens = math_ops.reduce_sum(lengths[-1])
        ones_mask = array_ops.ones([num_tokens], dtype=dtypes.bool)
        ragged_mask = cls.from_nested_row_lengths(
            ones_mask, lengths, validate=False)
        dense_ragged_mask = ragged_mask.to_tensor(default_value=False)
        masked_data = array_ops.boolean_mask(tensor, dense_ragged_mask)
        return cls.from_nested_row_lengths(masked_data, lengths, validate=False)

      # Handle ragged_rank>1 via recursion:
      # If the output should have multiple ragged dimensions, then first
      # flatten the tensor to eliminate all but the last ragged dimension,
      # and recursively convert that flattened tensor.  Then add on the splits
      # for the dimensions that we flattened out.
      if ragged_rank > 1:
        if tensor.shape.is_fully_defined():
          input_shape = tensor.shape.as_list()
          # The total number of elements in each dimension.  E.g., if
          # input_shape=[3, 4, 5, 6], then dim[2] has 3*4*5 elements in total.
          dim_size = np.cumprod(input_shape)
          new_shape = [dim_size[ragged_rank - 1]] + input_shape[ragged_rank:]
        else:
          dim_size = math_ops.cumprod(input_shape)
          new_shape = array_ops.concat(
              [[dim_size[ragged_rank - 1]], input_shape[ragged_rank:]], axis=0)
        flattened = array_ops.reshape(tensor, new_shape)
        result = cls.from_tensor(
            flattened, lengths, padding, row_splits_dtype=row_splits_dtype)

        # Re-add a uniform partition for each flattened-out dimension, working
        # from the innermost dimension outward.
        for axis in range(ragged_rank - 1, 0, -1):
          dim_len = tensor_shape.dimension_at_index(tensor.shape, axis).value
          if dim_len is None:
            dim_len = input_shape[axis]
          else:
            dim_len = constant_op.constant(dim_len, row_splits_dtype)
          result = RaggedTensor.from_uniform_row_length(
              values=result,
              uniform_row_length=dim_len,
              nrows=dim_size[axis - 1],
              validate=False)
        return result

      # If padding was specified, then use it to find row lengths.
      if padding is not None:
        padding = ops.convert_to_tensor(
            padding, name="padding", dtype=tensor.dtype)
        padding.shape.assert_is_compatible_with(tensor.shape[2:])

        # Find places where the padding is equal to the tensor.  (This will
        # broadcast `padding` across the outermost 2 dimensions of `tensor`,
        # so `has_default_value.shape = tensor.shape`.)
        has_default_value = math_ops.equal(padding, tensor)

        # If the padding isn't a scalar, then require that all values in the
        # padding match each item in the tensor.  After this block of code,
        # `has_default.shape = tensor.shape[:2]`.  (Unfortunately, we can't just
        # use reduce_all for both cases, because when you pass an empty `axis`
        # list to reduce_all, it reduces all axes; but we want it to reduce no
        # axes -- i.e., to be a no-op.)
        tensor_rank = array_ops.rank(tensor)
        reduce_axis = math_ops.range(2, tensor_rank)
        has_default = control_flow_ops.cond(
            tensor_rank > 2,
            lambda: math_ops.reduce_all(has_default_value, axis=reduce_axis),
            lambda: has_default_value)
        has_default.set_shape(tensor_shape.TensorShape([None, None]))
        has_default.set_shape(tensor.shape[:2])

        # Use has_default to find the length of each row: for each
        # non-default item in a row, calculate the length that the row needs to
        # have to include that item; and then take the max of those values
        # (across each row).
        has_nondefault = math_ops.logical_not(has_default)
        has_nondefault = math_ops.cast(has_nondefault, row_splits_dtype)
        length_for_nondefault_value = (
            has_nondefault *
            array_ops.expand_dims(math_ops.range(1, ncols + 1), 0))
        lengths = math_ops.reduce_max(length_for_nondefault_value, axis=1)

      if lengths is not None:
        # If we have lengths (either directly supplied, or computed from
        # paddings), then use those to construct splits; and then use masking
        # to get the corresponding values.
        lengths = ragged_util.convert_to_int_tensor(lengths, "lengths",
                                                    row_splits_dtype)
        lengths.shape.assert_has_rank(1)
        # Clamp lengths to [0, ncols]: negative lengths count as zero, and
        # over-long lengths are truncated to the row width.
        lengths = math_ops.minimum(lengths, ncols)
        lengths = math_ops.maximum(lengths, 0)
        limits = math_ops.cumsum(lengths)
        splits = array_ops.concat(
            [array_ops.zeros([1], row_splits_dtype), limits], axis=0)
        mask = array_ops.sequence_mask(lengths, maxlen=ncols)
        values = array_ops.boolean_mask(tensor, mask)
        return cls.from_row_splits(values, splits, validate=False)

      # If neither padding nor lengths were specified, then create a splits
      # vector that contains no default values, and reshape the input tensor
      # to form the values for the RaggedTensor.
      values_shape = array_ops.concat(
          [[input_shape[0] * input_shape[1]], input_shape[2:]], axis=0)
      values = array_ops.reshape(tensor, values_shape)
      const_nrows = tensor_shape.dimension_at_index(tensor.shape, 0).value
      const_ncols = tensor_shape.dimension_at_index(tensor.shape, 1).value
      # Prefer statically-known sizes (as constants) over dynamic shape slices.
      if const_nrows is not None:
        nrows = constant_op.constant(const_nrows, row_splits_dtype)
      else:
        nrows = input_shape[0]
      if const_ncols is not None:
        ncols = constant_op.constant(const_ncols, row_splits_dtype)
      else:
        ncols = input_shape[1]
      return RaggedTensor.from_uniform_row_length(
          values=values, uniform_row_length=ncols, nrows=nrows, validate=False)
1775
  def to_tensor(self, default_value=None, name=None, shape=None):
    """Converts this `RaggedTensor` into a `tf.Tensor`.

    If `shape` is specified, then the result is padded and/or truncated to
    the specified shape.

    Examples:

    >>> rt = tf.ragged.constant([[9, 8, 7], [], [6, 5], [4]])
    >>> print(rt.to_tensor())
    tf.Tensor(
        [[9 8 7] [0 0 0] [6 5 0] [4 0 0]], shape=(4, 3), dtype=int32)
    >>> print(rt.to_tensor(shape=[5, 2]))
    tf.Tensor(
        [[9 8] [0 0] [6 5] [4 0] [0 0]], shape=(5, 2), dtype=int32)

    Args:
      default_value: Value to set for indices not specified in `self`. Defaults
        to zero.  `default_value` must be broadcastable to
        `self.shape[self.ragged_rank + 1:]`.
      name: A name prefix for the returned tensors (optional).
      shape: The shape of the resulting dense tensor.  In particular,
        `result.shape[i]` is `shape[i]` (if `shape[i]` is not None), or
        `self.bounding_shape(i)` (otherwise). `shape.rank` must be `None` or
        equal to `self.rank`.

    Returns:
      A `Tensor` with shape `ragged.bounding_shape(self)` and the
      values specified by the non-empty values in `self`.  Empty values are
      assigned `default_value`.
    """
    with ops.name_scope(name, "RaggedToTensor", [self, default_value, shape]):
      if default_value is not None:
        default_value = ops.convert_to_tensor(
            default_value, name="default_value", dtype=self.dtype)
      # Decompose the partitioning info into the (type, tensor) form expected
      # by the ragged_tensor_to_tensor kernel.
      type_tensor_pairs = _get_row_partition_type_tensor_pairs(self)
      row_partition_types = [x[0] for x in type_tensor_pairs]
      row_partition_tensors = [x[1] for x in type_tensor_pairs]
      if default_value is None:
        default_value = array_ops.zeros((), self.dtype)

      # If `shape` is a list/tuple that mixes Tensors with ints, stack it into
      # a single shape tensor so it can be fed to the kernel.
      if (isinstance(shape, (list, tuple)) and
          any(isinstance(v, ops.Tensor) for v in shape) and
          all(isinstance(v, (int, ops.Tensor)) for v in shape)):
        shape = array_ops.stack(shape)

      shape_tensor = _shape_as_tensor(shape, row_partition_tensors[0].dtype)
      tensor = gen_ragged_conversion_ops.ragged_tensor_to_tensor(
          shape=shape_tensor,
          values=self.flat_values,
          default_value=default_value,
          row_partition_types=row_partition_types,
          row_partition_tensors=row_partition_tensors)

      ragged_shape = self.shape

      if ragged_shape.rank is not None and not isinstance(shape, ops.Tensor):
        # Merged self.shape and shape, favoring the second one as it takes
        # into account potential padding added to the output.
        shape = tensor_shape.as_shape(shape)
        if shape.rank is None:
          output_shape = ragged_shape
        else:
          # At this point we can assume that shape.rank == ragged_shape.rank
          # because otherwise it would have failed earlier.
          output_shape = [
              s1 if s1 is not None else s2
              for (s1, s2) in zip(shape.as_list(), ragged_shape.as_list())
          ]
        tensor.set_shape(output_shape)

      return tensor
1848
  @classmethod
  @dispatch.add_dispatch_support
  def from_sparse(cls, st_input, name=None, row_splits_dtype=dtypes.int64):
    """Converts a 2D `tf.sparse.SparseTensor` to a `RaggedTensor`.

    Each row of the `output` `RaggedTensor` will contain the explicit values
    from the same row in `st_input`.  `st_input` must be ragged-right.  If it
    is not ragged-right, then an error will be generated.

    Example:

    >>> indices = [[0, 0], [0, 1], [0, 2], [1, 0], [3, 0]]
    >>> st = tf.sparse.SparseTensor(indices=indices,
    ...                             values=[1, 2, 3, 4, 5],
    ...                             dense_shape=[4, 3])
    >>> tf.RaggedTensor.from_sparse(st).to_list()
    [[1, 2, 3], [4], [], [5]]

    Currently, only two-dimensional `SparseTensors` are supported.

    Args:
      st_input: The sparse tensor to convert.  Must have rank 2.
      name: A name prefix for the returned tensors (optional).
      row_splits_dtype: `dtype` for the returned `RaggedTensor`'s `row_splits`
        tensor.  One of `tf.int32` or `tf.int64`.

    Returns:
      A `RaggedTensor` with the same values as `st_input`.
      `output.ragged_rank = rank(st_input) - 1`.
      `output.shape = [st_input.dense_shape[0], None]`.
    Raises:
      ValueError: If the number of dimensions in `st_input` is not known
        statically, or is not two.
    """
    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if not sparse_tensor.is_sparse(st_input):
      raise TypeError(f"Argument `st_input` must be of type SparseTensor, but "
                      f"is of type {type(st_input).__name__}.")
    with ops.name_scope(name, "RaggedFromSparse", [st_input]):
      st_input = sparse_tensor.convert_to_tensor_or_sparse_tensor(
          st_input, name="st_input")

      # Determine the static rank, from either dense_shape or indices; the
      # rank check below passes if either source confirms rank 2.
      if st_input.dense_shape.shape.ndims is None:
        static_rank_from_dense_shape = None
      else:
        static_rank_from_dense_shape = st_input.dense_shape.shape.dims[0].value

      if st_input.indices.shape.ndims is None:
        static_rank_from_indices = None
      else:
        static_rank_from_indices = st_input.indices.shape.dims[1].value

      if static_rank_from_dense_shape != 2 and static_rank_from_indices != 2:
        raise ValueError("rank(st_input) must be 2.")

      with ops.control_dependencies(
          _assert_sparse_indices_are_ragged_right(st_input.indices)):
        # Treat sparse row indices as segment ids to generate a splits tensor
        # that we can pair with the sparse tensor values.  (Ignore sparse column
        # indices.)
        segment_ids = math_ops.cast(st_input.indices[:, 0], row_splits_dtype)
        num_segments = math_ops.cast(st_input.dense_shape[0], row_splits_dtype)
        return cls.from_value_rowids(
            st_input.values, segment_ids, num_segments, validate=False)
1913
1914  def to_sparse(self, name=None):
1915    """Converts this `RaggedTensor` into a `tf.sparse.SparseTensor`.
1916
1917    Example:
1918
1919    >>> rt = tf.ragged.constant([[1, 2, 3], [4], [], [5, 6]])
1920    >>> print(rt.to_sparse())
1921    SparseTensor(indices=tf.Tensor(
1922                     [[0 0] [0 1] [0 2] [1 0] [3 0] [3 1]],
1923                     shape=(6, 2), dtype=int64),
1924                 values=tf.Tensor([1 2 3 4 5 6], shape=(6,), dtype=int32),
1925                 dense_shape=tf.Tensor([4 3], shape=(2,), dtype=int64))
1926
1927    Args:
1928      name: A name prefix for the returned tensors (optional).
1929
1930    Returns:
1931      A SparseTensor with the same values as `self`.
1932    """
1933    with ops.name_scope(name, "RaggedToSparse", [self]):
1934      result = gen_ragged_conversion_ops.ragged_tensor_to_sparse(
1935          self.nested_row_splits, self.flat_values, name=name)
1936      return sparse_tensor.SparseTensor(result.sparse_indices,
1937                                        result.sparse_values,
1938                                        result.sparse_dense_shape)
1939
1940  @classmethod
1941  def _from_variant(cls,
1942                    variant,
1943                    dtype,
1944                    output_ragged_rank,
1945                    input_ragged_rank=None,
1946                    row_splits_dtype=dtypes.int64,
1947                    name=None):
1948    """Converts a `variant` Tensor into a `RaggedTensor`.
1949
1950    The input `variant` could be a scalar, meaning it encodes a single
1951    `RaggedTensor` with ragged_rank `output_ragged_rank`. Alternatively it could
1952    have an arbitrary rank, in which case each element is decoded into a
1953    `RaggedTensor` with ragged_rank `input_ragged_rank` and these are then
1954    stacked according to the input shape to output a single `RaggedTensor`
1955    with ragged_rank `output_ragged_rank`. If `input_ragged_rank` is not
1956    provided, it is inferred dynamically as `output_ragged_rank` -
1957    `rank(variant)`. If `input_ragged_rank` is provided, the following must be
1958    true: `output_ragged_rank` = `input_ragged_rank` + `rank(variant)`.
1959
1960    Example:
1961
1962    >>> rt = tf.ragged.constant([[0], [1, 2]])
1963    >>> et = rt._to_variant()
1964    >>> stacked_et = tf.stack([et, et])
1965    >>> tf.RaggedTensor._from_variant(  # scalar input.
1966    ...     et, dtype=tf.int32, output_ragged_rank=1).to_list()
1967    [[0], [1, 2]]
1968    >>> tf.RaggedTensor._from_variant(  # batched input.
1969    ...     stacked_et, dtype=tf.int32, output_ragged_rank=2).to_list()
1970    [[[0], [1, 2]], [[0], [1, 2]]]
1971
1972    Args:
1973      variant: A `variant` Tensor representing an encoded (possibly
1974        nested-batched) `RaggedTensor`.
1975      dtype: The dtype of the encoded `RaggedTensor`.
1976      output_ragged_rank: The expected ragged rank of the output `RaggedTensor`.
1977      input_ragged_rank: The ragged rank of each encoded `RaggedTensor`. This is
1978        optional and inferred dynamically if not provided.
1979      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
1980        of `tf.int32` or `tf.int64`.
1981      name: A name prefix for the returned tensors (optional).
1982
1983    Returns:
1984      A `RaggedTensor` of dtype `dtype` and ragged rank `output_ragged_rank`.
1985
1986    Raises:
1987      ValueError: If the input rank is known, `input_ragged_rank` is provided
1988          and `output_ragged_rank` = `input_ragged_rank` + `rank(variant)` does
1989          not hold.
1990    """
1991    variant = ops.convert_to_tensor(
1992        variant, name="variant", dtype=dtypes.variant)
1993    if (variant.shape.ndims is not None and input_ragged_rank is not None and
1994        output_ragged_rank != input_ragged_rank + variant.shape.ndims):
1995      raise ValueError(
1996          f"Argument `output_ragged_rank` ({output_ragged_rank}) must be equal "
1997          f"to `input_ragged_rank` + `variant.shape.ndims` "
1998          f"({input_ragged_rank} + {variant.shape.ndims}).")
1999    input_ragged_rank = -1 if input_ragged_rank is None else input_ragged_rank
2000    with ops.name_scope(
2001        name, "RaggedFromVariant",
2002        [variant, dtype, input_ragged_rank, output_ragged_rank]):
2003      result = gen_ragged_conversion_ops.ragged_tensor_from_variant(
2004          variant, input_ragged_rank, output_ragged_rank, dtype,
2005          row_splits_dtype, name)
2006      return cls.from_nested_row_splits(
2007          result.output_dense_values,
2008          result.output_nested_splits,
2009          validate=False)
2010
2011  def _to_variant(self, batched_input=False, name=None):
2012    """Converts this `RaggedTensor` into a `variant` Tensor.
2013
2014    If `batched_input` is `True`, then the `RaggedTensor` is unbatched along the
2015    zero-th dimension, each component `RaggedTensor` is encoded into a scalar
2016    `variant` Tensor, and these are stacked to return a 1-D `variant` Tensor.
2017    If `batched_input` is `False`, then the `RaggedTensor` is encoded as is and
2018    a scalar `variant` Tensor is returned.
2019
2020    Example:
2021    >>> rt = tf.ragged.constant([[[0]], [[1]], [[2]]])
2022    >>> rt._to_variant().shape.as_list()
2023    []
2024    >>> rt._to_variant(batched_input=True).shape.as_list()
2025    [3]
2026
2027    Args:
2028      batched_input: If `True`, the `RaggedTensor` is unbatched and converted to
2029        a `variant` vector. Set to `False` by default.
2030      name: A name prefix for the returned tensors (optional).
2031
2032    Returns:
2033      A `variant` Tensor that encodes this `RaggedTensor`.
2034    """
2035    with ops.name_scope(name, "RaggedToVariant", [self, batched_input]):
2036      return gen_ragged_conversion_ops.ragged_tensor_to_variant(
2037          self.nested_row_splits, self.flat_values, batched_input, name)
2038
2039  #=============================================================================
2040  # String Encoding
2041  #=============================================================================
2042  def __repr__(self):
2043    if self._is_eager():
2044      return "<tf.RaggedTensor %s>" % self.to_list()
2045    else:
2046      return "tf.RaggedTensor(values=%s, row_splits=%s)" % (self.values,
2047                                                            self.row_splits)
2048
2049  #=============================================================================
2050  # Eager Execution Mode
2051  #=============================================================================
2052
2053  def numpy(self):
2054    """Returns a numpy `array` with the values for this `RaggedTensor`.
2055
2056    Requires that this `RaggedTensor` was constructed in eager execution mode.
2057
2058    Ragged dimensions are encoded using numpy `arrays` with `dtype=object` and
2059    `rank=1`, where each element is a single row.
2060
2061    #### Examples
2062
2063    In the following example, the value returned by `RaggedTensor.numpy()`
2064    contains three numpy `array` objects: one for each row (with `rank=1` and
2065    `dtype=int64`), and one to combine them (with `rank=1` and `dtype=object`):
2066
2067    >>> tf.ragged.constant([[1, 2, 3], [4, 5]], dtype=tf.int64).numpy()
2068    array([array([1, 2, 3]), array([4, 5])], dtype=object)
2069
2070    Uniform dimensions are encoded using multidimensional numpy `array`s.  In
2071    the following example, the value returned by `RaggedTensor.numpy()` contains
2072    a single numpy `array` object, with `rank=2` and `dtype=int64`:
2073
2074    >>> tf.ragged.constant([[1, 2, 3], [4, 5, 6]], dtype=tf.int64).numpy()
2075    array([[1, 2, 3], [4, 5, 6]])
2076
2077    Returns:
2078      A numpy `array`.
2079    """
2080    if not self._is_eager():
2081      raise ValueError("RaggedTensor.numpy() is only supported in eager mode.")
2082    values = self.values.numpy()
2083    splits = self.row_splits.numpy()
2084    rows = [values[splits[i]:splits[i + 1]] for i in range(len(splits) - 1)]
2085    if not rows:
2086      return np.zeros((0, 0) + values.shape[1:], dtype=values.dtype)
2087    # Note: if `rows` have ragged lengths, then they will be stored in a
2088    # np.ndarray with dtype=object and rank=1.  If they have uniform lengths,
2089    # they will be combined into a single np.ndarray with dtype=row.dtype and
2090    # rank=row.rank+1.
2091    return np.array(rows)
2092
2093  def to_list(self):
2094    """Returns a nested Python `list` with the values for this `RaggedTensor`.
2095
2096    Requires that `rt` was constructed in eager execution mode.
2097
2098    Returns:
2099      A nested Python `list`.
2100    """
2101    if not isinstance(self.row_splits, ops.EagerTensor):
2102      raise ValueError("to_list can only be used in eager mode.")
2103    row_splits = self.row_splits.numpy().tolist()
2104    values = self.values
2105
2106    if isinstance(values, RaggedTensor):
2107      return [
2108          values[row_splits[i]:row_splits[i + 1]].to_list()
2109          for i in range(len(row_splits) - 1)
2110      ]
2111    else:
2112      # Convert values to a Python list.
2113      if hasattr(values, "numpy"):
2114        values_as_list = values.numpy().tolist()
2115      elif hasattr(values, "to_list"):
2116        values_as_list = values.to_list()
2117      else:
2118        raise ValueError("values must be convertible to a list")
2119
2120      return [
2121          values_as_list[row_splits[i]:row_splits[i + 1]]
2122          for i in range(len(row_splits) - 1)
2123      ]
2124
2125  def _eager_value(self):
2126    """Returns a RaggedTensorValue for self.  Requires self._is_eager()=true."""
2127    value = self.flat_values.numpy()
2128    for row_splits in reversed(self.nested_row_splits):
2129      value = ragged_tensor_value.RaggedTensorValue(value, row_splits.numpy())
2130    return value
2131
2132  def _is_eager(self):
2133    """Returns True if values & row_splits Tensors are all `EagerTensor`s."""
2134    rt = self
2135    while isinstance(rt, RaggedTensor):
2136      if not isinstance(rt.row_splits, ops.EagerTensor):
2137        return False
2138      rt = rt.values
2139    return isinstance(rt, ops.EagerTensor)
2140
2141  #=============================================================================
2142  # Operators
2143  #=============================================================================
2144  # To avoid circular dependencies, we define stub methods for operators here,
2145  # and then override them when the ragged_operators module is imported.
2146
2147  def _overloaded_operator(name):  # pylint: disable=no-self-argument
2148
2149    def stub(*args, **kwargs):
2150      del args, kwargs
2151      raise ValueError(
2152          f"You must import 'tensorflow.python.ops.ragged.ragged_ops' "
2153          f"before using RaggedTensor.{name}.")
2154
2155    return stub
2156
2157  __getitem__ = _overloaded_operator("__getitem__")
2158  __ge__ = _overloaded_operator("__ge__")
2159  __gt__ = _overloaded_operator("__gt__")
2160  __le__ = _overloaded_operator("__le__")
2161  __lt__ = _overloaded_operator("__lt__")
2162  __and__ = _overloaded_operator("__and__")
2163  __rand__ = _overloaded_operator("__rand__")
2164  __invert__ = _overloaded_operator("__invert__")
2165  __ror__ = _overloaded_operator("__ror__")
2166  __or__ = _overloaded_operator("__or__")
2167  __xor__ = _overloaded_operator("__xor__")
2168  __rxor__ = _overloaded_operator("__rxor__")
2169  __abs__ = _overloaded_operator("__abs__")
2170  __add__ = _overloaded_operator("__add__")
2171  __radd__ = _overloaded_operator("__radd__")
2172  __div__ = _overloaded_operator("__div__")
2173  __rdiv__ = _overloaded_operator("__rdiv__")
2174  __floordiv__ = _overloaded_operator("__floordiv__")
2175  __rfloordiv__ = _overloaded_operator("__rfloordiv__")
2176  __mod__ = _overloaded_operator("__mod__")
2177  __rmod__ = _overloaded_operator("__rmod__")
2178  __mul__ = _overloaded_operator("__mul__")
2179  __rmul__ = _overloaded_operator("__rmul__")
2180  __neg__ = _overloaded_operator("__neg__")
2181  __pow__ = _overloaded_operator("__pow__")
2182  __rpow__ = _overloaded_operator("__rpow__")
2183  __sub__ = _overloaded_operator("__sub__")
2184  __rsub__ = _overloaded_operator("__rsub__")
2185  __truediv__ = _overloaded_operator("__truediv__")
2186  __rtruediv__ = _overloaded_operator("__rtruediv__")
2187  del _overloaded_operator
2188
2189  #=============================================================================
2190  # Name Scope
2191  #=============================================================================
2192
2193  # This private function is used by ops.name_scope to ensure that all of the
2194  # input tensors for the scope belong to the same graph.  Defining this means
2195  # that you may include `RaggedTensor` objects in the name_scope `values`
2196  # list.
2197  def _as_graph_element(self):
2198    """Convert `self` to a graph element."""
2199    values = self.values
2200    while isinstance(values, RaggedTensor):
2201      values = values.values
2202    return values
2203
2204  #=============================================================================
2205  # Composite Tensor
2206  #=============================================================================
2207
  @property
  def _type_spec(self):
    # CompositeTensor API: the TypeSpec describing this RaggedTensor.
    return RaggedTensorSpec.from_value(self)
2211
  def _shape_invariant_to_type_spec(self, shape):
    """Returns a TypeSpec for this value with its shape relaxed to `shape`."""
    return RaggedTensorSpec(shape, self.dtype, self.ragged_rank,
                            self.row_splits.dtype)
2215
  def consumers(self):
    """Returns the consumers of this RaggedTensor's component tensors."""
    return self._consumers()
2218
2219
def is_ragged(value):
  """Returns true if `value` is a ragged tensor or ragged tensor value."""
  ragged_types = (RaggedTensor, ragged_tensor_value.RaggedTensorValue)
  return isinstance(value, ragged_types)
2224
2225
def match_row_splits_dtypes(*tensors, **kwargs):
  """Return a copy of `tensors` with row_splits all having the same dtype.

  Args:
    *tensors: A list of Tensors or RaggedTensors.
    **kwargs: If 'return_dtype=True', then return a tuple (dtype, tensors),
      where `dtype` is the data type used by row-splits, and `tensors` is the
      converted list of `Tensors` and `RaggedTensors`.

  Returns:
    The converted list of `Tensors` and `RaggedTensors`.
  """
  return_dtype = kwargs.pop("return_dtype", False)
  if kwargs:
    raise ValueError(f"Unexpected keyword args {kwargs}.")

  # Collect the row_splits dtypes of the ragged inputs (dense inputs have no
  # row partitions and are ignored here).
  ragged_dtypes = [
      t.row_splits.dtype for t in tensors if isinstance(t, RaggedTensor)
  ]
  has_int32 = any(d == dtypes.int32 for d in ragged_dtypes)
  has_int64 = any(d != dtypes.int32 for d in ragged_dtypes)

  if has_int32 and has_int64:
    if not ragged_config.auto_cast_partition_dtype():
      raise ValueError("Input RaggedTensors have mismatched row_splits dtypes; "
                       "use RaggedTensor.with_row_splits_dtype() to convert "
                       "them to compatible dtypes.")
    # Auto-casting is enabled: upcast every ragged input to int64 splits.
    dtype = dtypes.int64
    tensors = tuple(
        t.with_row_splits_dtype(dtypes.int64)
        if isinstance(t, RaggedTensor) else t for t in tensors)
  elif has_int32:
    dtype = dtypes.int32
  else:
    dtype = dtypes.int64

  return (dtype, tensors) if return_dtype else tensors
2271
2272
2273#===============================================================================
2274# RaggedTensorSpec
2275#===============================================================================
@tf_export("RaggedTensorSpec")
@type_spec.register("tf.RaggedTensorSpec")
class RaggedTensorSpec(type_spec.BatchableTypeSpec):
  """Type specification for a `tf.RaggedTensor`."""

  __slots__ = [
      "_shape", "_dtype", "_ragged_rank", "_row_splits_dtype",
      "_flat_values_spec"
  ]

  @property
  def dtype(self):
    """The `tf.dtypes.DType` specified by this type for the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([["a"], ["b", "c"]], dtype=tf.string)
    >>> tf.type_spec_from_value(rt).dtype
    tf.string

    Returns:
      A `tf.dtypes.DType` of the values in the RaggedTensor.
    """
    return self._dtype

  @property
  def shape(self):
    """The statically known shape of the RaggedTensor.

    Examples:

    >>> rt = tf.ragged.constant([[0], [1, 2]])
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None])

    >>> rt = tf.ragged.constant([[[0, 1]], [[1, 2], [3, 4]]], ragged_rank=1)
    >>> tf.type_spec_from_value(rt).shape
    TensorShape([2, None, 2])

    Returns:
      A `tf.TensorShape` containing the statically known shape of the
      RaggedTensor. Ragged dimensions have a size of `None`.
    """
    return self._shape

  @property
  def ragged_rank(self):
    """The number of times the RaggedTensor's flat_values is partitioned.

    Defaults to `shape.ndims - 1`.

    Examples:

    >>> values = tf.ragged.constant([[1, 2, 3], [4], [5, 6], [7, 8, 9, 10]])
    >>> tf.type_spec_from_value(values).ragged_rank
    1

    >>> rt1 = tf.RaggedTensor.from_uniform_row_length(values, 2)
    >>> tf.type_spec_from_value(rt1).ragged_rank
    2

    Returns:
      A Python `int` indicating the number of times the underlying `flat_values`
      Tensor has been partitioned to add a new dimension.
      I.e., `tf.rank(rt) = tf.rank(rt.flat_values) + rt.ragged_rank`.
    """
    return self._ragged_rank

  @property
  def row_splits_dtype(self):
    """The `tf.dtypes.DType` of the RaggedTensor's `row_splits`.

    Examples:

    >>> rt = tf.ragged.constant([[1, 2, 3], [4]], row_splits_dtype=tf.int64)
    >>> tf.type_spec_from_value(rt).row_splits_dtype
    tf.int64

    Returns:
      A `tf.dtypes.DType` for the RaggedTensor's `row_splits` tensor. One
      of `tf.int32` or `tf.int64`.
    """
    return self._row_splits_dtype

  @property
  def flat_values_spec(self):
    """The `TypeSpec` of the flat_values of RaggedTensor.

    Returns:
      - The TypeSpec of flat_values.
      - None when the flat_values is a Tensor.
    """
    return self._flat_values_spec

  @property
  def value_type(self):
    """The Python type represented by this spec (`RaggedTensor` or `Tensor`)."""
    return RaggedTensor if self._ragged_rank > 0 else ops.Tensor

  def __init__(self,
               shape=None,
               dtype=dtypes.float32,
               ragged_rank=None,
               row_splits_dtype=dtypes.int64,
               flat_values_spec=None):
    """Constructs a type specification for a `tf.RaggedTensor`.

    Args:
      shape: The shape of the RaggedTensor, or `None` to allow any shape.  If a
        shape is specified, then all ragged dimensions must have size `None`.
      dtype: `tf.DType` of values in the RaggedTensor.
      ragged_rank: Python integer, the number of times the RaggedTensor's
        flat_values is partitioned.  Defaults to `shape.ndims - 1`.
      row_splits_dtype: `dtype` for the RaggedTensor's `row_splits` tensor. One
        of `tf.int32` or `tf.int64`.
      flat_values_spec: TypeSpec for flat_value of the RaggedTensor. It shall be
        provided when the flat_values is a CompositeTensor rather than Tensor.
        If both `dtype` and `flat_values_spec` are provided, `dtype` must
        be the same as `flat_values_spec.dtype`. (experimental)
    """
    self._shape = tensor_shape.as_shape(shape)
    self._row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
    if flat_values_spec is not None:
      # When a flat_values_spec is given, dtype may be inferred from it, but
      # an explicitly-passed dtype must agree with it.
      if dtype is None:
        dtype = flat_values_spec.dtype
      elif dtype != flat_values_spec.dtype:
        raise ValueError("dtype must be the same as flat_values_spec.dtype")
    elif dtype is None:
      raise ValueError(
          "At least one of dtype or flat_values_spec must be provided")
    self._dtype = dtypes.as_dtype(dtype)
    self._flat_values_spec = flat_values_spec

    rank = self._shape.ndims
    if ragged_rank is None:
      if rank is None:
        raise ValueError("Must specify ragged_rank or "
                         "a shape with a known rank.")
      ragged_rank = rank - 1
    self._ragged_rank = ragged_rank
    if not isinstance(self._ragged_rank, int):
      raise TypeError(f"Argument `ragged_rank` must be an int. "
                      f"Received {ragged_rank}.")

    if rank is not None:
      if ragged_rank >= rank:
        raise ValueError(f"Argument `ragged_rank` ({ragged_rank}) must be less "
                         f"than rank ({rank}).")

  def is_compatible_with(self, spec_or_value):
    """Returns true if `spec_or_value` is compatible with this TypeSpec."""
    # RaggedTensor with ragged_rank 0 can be compatible with raw flat_values.
    if self._ragged_rank == 0:
      if self._flat_values_spec is None:
        if isinstance(spec_or_value, (ops.Tensor, tensor_spec.TensorSpec)):
          return tensor_spec.TensorSpec(
              self._shape, self._dtype).is_compatible_with(spec_or_value)
      elif not isinstance(spec_or_value, (RaggedTensor, RaggedTensorSpec)):
        return self._flat_values_spec.is_compatible_with(spec_or_value)
    return super(RaggedTensorSpec, self).is_compatible_with(spec_or_value)

  def _serialize(self):
    """Returns the constructor args as a tuple (TypeSpec serialization)."""
    if self._flat_values_spec is None:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype)
    else:
      return (self._shape, self._dtype, self._ragged_rank,
              self._row_splits_dtype, self._flat_values_spec)

  @property
  def _component_specs(self):
    """TypeSpecs for the components: [flat_values, *nested_row_splits]."""
    if self._ragged_rank == 0:
      if self._flat_values_spec is not None:
        return [self._flat_values_spec]
      else:
        return [tensor_spec.TensorSpec(self._shape, self._dtype)]

    flat_values_spec = self._flat_values_spec
    if flat_values_spec is None:
      flat_values_shape = tensor_shape.TensorShape([None]).concatenate(
          self._shape[self._ragged_rank + 1:])
      flat_values_spec = tensor_spec.TensorSpec(flat_values_shape, self._dtype)
    # The outermost row_splits has length nrows+1 when nrows is known.
    outer_dim = tensor_shape.dimension_at_index(self._shape, 0)
    outer_splits_shape = [None if outer_dim is None else outer_dim + 1]
    inner_splits_spec = tensor_spec.TensorSpec([None], self._row_splits_dtype)

    specs = ([
        flat_values_spec,
        tensor_spec.TensorSpec(outer_splits_shape, self._row_splits_dtype)
    ] + [inner_splits_spec for _ in range(self._ragged_rank - 1)])
    return specs

  def _to_components(self, value):
    """Decomposes `value` into [flat_values, *nested_row_splits]."""
    if is_ragged(value):
      return [value.flat_values] + list(value.nested_row_splits)
    else:
      return [value]

  def _from_components(self, tensor_list):
    """Rebuilds a RaggedTensor (or RaggedTensorValue in TF1) from components."""
    result = tensor_list[0]
    if (all(isinstance(t, np.ndarray) for t in tensor_list) and
        not tf2.enabled()):
      # TF1 graph-mode path: build a RaggedTensorValue from numpy arrays.
      for row_splits in reversed(tensor_list[1:]):
        result = ragged_tensor_value.RaggedTensorValue(result, row_splits)
    else:
      if isinstance(tensor_list[0], np.ndarray):
        tensor_list = [ops.convert_to_tensor(t) for t in tensor_list]
        result = tensor_list[0]
      for row_splits in reversed(tensor_list[1:]):
        result = RaggedTensor(
            result,
            RowPartition.from_row_splits(row_splits, validate=False),
            internal=True)
    return result

  # The RaggedTensorSpec tensor_list encoding uses to/from_variant ops
  # to (un)box the component tensors in a way that allows for batching &
  # unbatching.
  @property
  def _flat_tensor_specs(self):
    # NOTE(mishragaurav): The default flat shape of a boxed `RaggedTensor` is
    # `[]` (scalar), but a `RaggedTensorSpec` can also represent a batch of
    # boxed `RaggedTensor` objects with shape `(...)` (and batches of batches,
    # etc.), so the flat shape must be unknown.
    return [tensor_spec.TensorSpec(None, dtypes.variant)]

  def _to_tensor_list(self, value):
    """Boxes `value` into a single variant tensor (unbatched encoding)."""
    # TODO(edloper): Update gen_ragged_conversion_ops that convert to and
    # from variant to include all of the row-partitioning tensors.
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
    if ragged_rank != self._ragged_rank:
      raise ValueError(f"Ragged rank of value {ragged_rank} does not match "
                       f"ragged rank of type {self._ragged_rank}.")
    if ragged_rank == 0:
      return [
          gen_ragged_conversion_ops.ragged_tensor_to_variant(
              (), value, batched_input=False)
      ]
    # pylint: disable=protected-access
    return [value._to_variant(batched_input=False)]

  def _to_batched_tensor_list(self, value):
    """Boxes `value` into a single variant tensor (batched encoding)."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    ragged_rank = value.ragged_rank if isinstance(value, RaggedTensor) else 0
    if ragged_rank != self._ragged_rank:
      raise ValueError(f"Ragged rank of value {ragged_rank} does not match "
                       f"ragged rank of type {self._ragged_rank}.")
    if ragged_rank == 0:
      # TODO(b/141789000) Update this to handle ragged_rank=0.
      raise ValueError(
          "_to_batched_tensor_list doesn't support ragged_rank=0 yet")
    # pylint: disable=protected-access
    return [value._to_variant(batched_input=True)]

  def _from_compatible_tensor_list(self, tensor_list):
    """Unboxes a variant tensor list back into a RaggedTensor."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    if self._ragged_rank < 0:
      raise ValueError(f"Argument `ragged_rank` must be non-negative. "
                       f"Received {self._ragged_rank}.")
    result = RaggedTensor._from_variant(  # pylint: disable=protected-access
        tensor_list[0],
        dtype=self._dtype,
        row_splits_dtype=self._row_splits_dtype,
        output_ragged_rank=self._ragged_rank)
    if self._shape.ndims is not None:
      # Propagate any static shape information from the spec to the result.
      if isinstance(result, RaggedTensor):
        outer_dim = tensor_shape.dimension_value(self._shape[0])
        if outer_dim is not None:
          result.row_splits.set_shape([outer_dim + 1])
        result._set_shape(self._shape)  # pylint: disable=protected-access
      else:
        result.set_shape(self._shape)
    return result

  def _batch(self, batch_size):
    """Returns the spec for a batch of values described by this spec."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    return RaggedTensorSpec(
        tensor_shape.TensorShape([batch_size]).concatenate(self._shape),
        self._dtype, self._ragged_rank + 1, self._row_splits_dtype)

  def _unbatch(self):
    """Returns the spec for a single element of a batch described by this spec."""
    if self._flat_values_spec is not None:
      raise ValueError("Customized value_type is not supported.")
    # Note: Negative ragged_rank is allowed here because the dataset could be
    # subsequently batched again. If ragged_rank > 1, assume row_splits_dtype is
    # consistent. Errors are handled in
    # RaggedTensorSpec._from_compatible_tensor_list()
    return RaggedTensorSpec(self._shape[1:], self._dtype, self._ragged_rank - 1,
                            self._row_splits_dtype)

  def _to_legacy_output_types(self):
    return self._dtype

  def _to_legacy_output_shapes(self):
    return self._shape

  def _to_legacy_output_classes(self):
    return self

  @classmethod
  def from_value(cls, value):
    """Constructs a spec from a `RaggedTensor` or `RaggedTensorValue`."""
    if (isinstance(value, ragged_tensor_value.RaggedTensorValue) or
        isinstance(value.flat_values, ops.Tensor)):
      return cls(
          shape=value.shape,
          dtype=value.values.dtype,
          ragged_rank=value.ragged_rank,
          row_splits_dtype=value.row_splits.dtype)
    else:
      # flat_values is itself a CompositeTensor; record its spec.
      return cls(
          shape=value.shape,
          dtype=value.values.dtype,
          ragged_rank=value.ragged_rank,
          row_splits_dtype=value.row_splits.dtype,
          flat_values_spec=type_spec.type_spec_from_value(value.flat_values))
2594
2595
# Let type_spec_from_value() produce a RaggedTensorSpec for RaggedTensorValues.
type_spec.register_type_spec_from_value_converter(
    ragged_tensor_value.RaggedTensorValue, RaggedTensorSpec.from_value)
2598
2599
2600#===============================================================================
2601# Convert value -> tensor
2602#===============================================================================
def convert_to_tensor_or_ragged_tensor(value,
                                       dtype=None,
                                       preferred_dtype=None,
                                       name=None):
  """Converts value to a `RaggedTensor` or `Tensor`.

  * If `value` is a `RaggedTensor`, then return it as-is.
  * If `value` is a `RaggedTensorValue`, return a corresponding constant
    `RaggedTensor`.
  * Otherwise, use `convert_to_tensor` to convert `value` to a `Tensor`.

  Args:
    value: A `RaggedTensor`, a `RaggedTensorValue`, or an object whose type has
      a registered `Tensor` conversion function.
    dtype: Optional element type for the returned tensor.  If missing the type
      is inferred from the type of `value`.
    preferred_dtype: Optional element type for the returned tensor, used when
      dtype is None.  This argument has no effect if `value` is already a
      tensor, or when conversion is not possible.
    name: Optional name to use if a new `Tensor` is created.

  Returns:
    A `Tensor` or `RaggedTensor`.
  """
  if isinstance(value, RaggedTensor):
    # Already a RaggedTensor: just verify dtype compatibility.
    if dtype and not dtype.is_compatible_with(value.dtype):
      raise ValueError(f"Tensor conversion requested dtype {dtype.name} for "
                       f"RaggedTensor with dtype {value.dtype.name}: {value}.")
    return value

  if isinstance(value, ragged_tensor_value.RaggedTensorValue):
    # Build a constant RaggedTensor from the value's components.
    with ops.name_scope(name, "ConvertToTensorOrRaggedTensor", []):
      flat_values = ops.convert_to_tensor(
          value=value.flat_values,
          dtype=dtype,
          preferred_dtype=preferred_dtype,
          name="flat_values")
      return RaggedTensor.from_nested_row_splits(
          flat_values, value.nested_row_splits, validate=False)

  # Fall back to dense Tensor conversion.
  return ops.convert_to_tensor_v2_with_dispatch(
      value=value, dtype=dtype, dtype_hint=preferred_dtype, name=name)
2644
2645
def _convert_to_ragged_tensor_values(value):
  """Converts value to supported RaggedTensor value.

  Objects that are already of a supported value type are returned unchanged;
  anything else is converted to a `Tensor` or `RaggedTensor`.

  Args:
    value: An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
      value types, or an object whose type has a registered `Tensor` conversion
      function.

  Returns:
    An object of `Tensor`, `RaggedTensor` or registered RaggedTensor
    value types.
  """
  if _is_supported_ragged_values_type(value):
    return value
  return convert_to_tensor_or_ragged_tensor(value, name="values")
2665
2666
2667#===============================================================================
2668# Register RaggedTensor for use with session.run.
2669#===============================================================================
2670def _ragged_tensor_value_from_components(components):
2671  components = list(components)
2672  value = components.pop()
2673  while components:
2674    value = ragged_tensor_value.RaggedTensorValue(value, components.pop())
2675  return value
2676
2677
def _ragged_tensor_session_fetch(rt):
  """Session-fetch hook: returns (component_tensors, rebuild_fn) for `rt`."""
  components = rt.nested_row_splits + (rt.flat_values,)
  return (components, _ragged_tensor_value_from_components)
2681
2682
2683def _ragged_tensor_session_feed(feed_key, feed_val):
2684  key_components = feed_key.nested_row_splits + (feed_key.flat_values,)
2685  val_components = feed_val.nested_row_splits + (feed_val.flat_values,)
2686  return zip(key_components, val_components)
2687
2688
2689def _ragged_tensor_session_feed_for_partial_run(feed_key):
2690  return feed_key.nested_row_splits + (feed_key.flat_values,)
2691
2692
# Register the hooks above so session.run() can fetch and feed RaggedTensors.
session.register_session_run_conversion_functions(
    RaggedTensor, _ragged_tensor_session_fetch, _ragged_tensor_session_feed,
    _ragged_tensor_session_feed_for_partial_run)
2696
2697
2698#===============================================================================
2699# RaggedTensorType
2700#===============================================================================
2701class RaggedTensorType(object):
2702  """Encoding of a static type for a `RaggedTensor`.
2703
2704  Use this type to express/declare that an output must have the type of
2705  `RaggedTensor`.
2706  """
2707
2708  def __init__(self, dtype, ragged_rank, row_splits_dtype=dtypes.int64):
2709    """Initializes a RaggedTensorType object.
2710
2711    Args:
2712      dtype: data type of the `RaggedTensor`'s inner values.
2713      ragged_rank: ragged_rank of the declared `RaggedTensor`.
2714      row_splits_dtype: data type for the `RaggedTensor`'s row splits.
2715        One of: `tf.int32` or `tf.int64`.
2716    """
2717    row_splits_dtype = dtypes.as_dtype(row_splits_dtype)
2718    self._dtype = dtype
2719    self._ragged_rank = ragged_rank
2720    self._row_splits_dtype = row_splits_dtype
2721
2722  dtype = property(lambda self: self._dtype)
2723  ragged_rank = property(lambda self: self._ragged_rank)
2724  row_splits_dtype = property(lambda self: self._row_splits_dtype)
2725
2726  def __repr__(self):
2727    return "RaggedTensorType(%r, %r, %r)" % (self.dtype, self.ragged_rank,
2728                                             self.row_splits_dtype)
2729
2730
2731#===============================================================================
2732# Helper Functions
2733#===============================================================================
def _assert_sparse_indices_are_ragged_right(indices):
  """Checks that the given SparseTensor.indices tensor is ragged-right.

  Example: `indices = [[0, 0], [0, 1], [2, 0], [3, 1]]` is not ragged right
  because the entry `[3, 1]` skips a cell.

  Args:
    indices: The SparseTensor indices to check.

  Returns:
    A list of control dependency op tensors.
  """
  # Split each index into its outer dimensions and its innermost coordinate.
  index_prefix = indices[:, :-1]
  index_suffix = indices[:, -1]

  # Check whether each index is starting a new row in the innermost dimension
  # (prefix[i] != prefix[i-1]) or continuing a row (prefix[i] == prefix[i-1]).
  # (Note: this skips the first index; we will check that separately below.)
  index_prefix_changed = math_ops.reduce_any(
      math_ops.not_equal(index_prefix[1:], index_prefix[:-1]), axis=1)

  # Check two cases:
  #   * For indices that start a new row: index_suffix[i] must be zero.
  #   * For indices that continue a row: index_suffix[i] must be equal to
  #     index_suffix[i-1]+1.
  index_ok = array_ops.where(
      index_prefix_changed, math_ops.equal(index_suffix[1:], 0),
      math_ops.equal(index_suffix[1:], index_suffix[:-1] + 1))

  # Also check that the very first index didn't skip any cells.  The first
  # index starts a new row (by definition), so its suffix should be zero.
  # (Using `[:1]` rather than `[0]` keeps this well-defined when `indices`
  # is empty.)
  sparse_indices_are_ragged_right = math_ops.logical_and(
      math_ops.reduce_all(math_ops.equal(index_suffix[:1], 0)),
      math_ops.reduce_all(index_ok))

  message = [
      "SparseTensor is not right-ragged", "SparseTensor.indices =", indices
  ]
  return [control_flow_ops.Assert(sparse_indices_are_ragged_right, message)]
2773
2774
@ops.RegisterGradient("RaggedTensorToSparse")
def _ragged_tensor_to_sparse_gradient(op, unused_sparse_indices_grad,
                                      sparse_values_grad,
                                      unused_sparse_shape_grad):
  """Gradient for RaggedTensorToSparse."""
  # The op's inputs are (*nested_row_splits, flat_values).
  flat_values = op.inputs[-1]
  num_splits = len(op.inputs) - 1

  # No gradient for the RaggedTensor's nested_row_splits.
  splits_grads = [None] * num_splits

  # Gradient for the RaggedTensor's flat_values is formed by reshaping
  # the gradient for the SparseTensor's values.
  values_grad = array_ops.reshape(sparse_values_grad,
                                  array_ops.shape(flat_values))

  return splits_grads + [values_grad]
2793
2794
def _assert_monotonic_increasing(tensor, message=None):
  """Asserts that `tensor`'s values are non-decreasing."""
  deltas = tensor[1:] - tensor[:-1]
  return check_ops.assert_non_negative(deltas, message=message)
2798
2799
def _assert_zero(tensor, message=None):
  """Asserts that every element of `tensor` equals zero."""
  zero = constant_op.constant(0, dtype=tensor.dtype)
  return check_ops.assert_equal(tensor, zero, message=message)
2803
2804
def _nrows(tensor, out_type=dtypes.int32):
  """Returns the size of `tensor`'s outermost dimension as a scalar."""
  if isinstance(tensor, RaggedTensor):
    return tensor.nrows(out_type=out_type)
  return array_ops.shape(tensor, out_type=out_type)[0]
2810
2811
def merge_dims(value, outer_axis, inner_axis):
  """Merges value[outer_axis...inner_axis] into a single dimension.

  See `RaggedTensor.merge_dims()` for more details.  This helper differs from
  `RaggedTensor.merge_dims()` in that `value` may be a dense or ragged tensor.

  Args:
    value: A `RaggedTensor` or `Tensor`
    outer_axis: `int`
    inner_axis: `int`

  Returns:
    A flattened `RaggedTensor` or `Tensor`.
  """
  if outer_axis == inner_axis:
    return value

  # Flatten outer dimensions of a RaggedTensor by just taking its values.
  while outer_axis == 0 and isinstance(value, RaggedTensor):
    value = value.values
    inner_axis -= 1
    if inner_axis == 0:
      return value

  # Flatten non-Ragged tensors using tf.reshape().
  if not isinstance(value, RaggedTensor):
    if value.shape.is_fully_defined():
      # Static shape known: compute the merged shape in Python.
      old_shape = value.shape.as_list()
      new_shape = old_shape[:outer_axis] + [-1] + old_shape[inner_axis + 1:]
    else:
      # Dynamic shape: build the merged shape as a tensor.
      old_shape = array_ops.shape(value)
      new_shape = array_ops.concat(
          [old_shape[:outer_axis], [-1], old_shape[inner_axis + 1:]], axis=0)
    return array_ops.reshape(value, new_shape)

  # Handle outer_axis>1 via recursion.
  if outer_axis > 1:
    return value.with_values(
        merge_dims(value.values, outer_axis - 1, inner_axis - 1))

  # At this point, we know outer_axis == 1, and value is a RaggedTensor.
  # So we need to flatten the values and build a corresponding splits tensor.
  new_values = value.values
  new_splits = value.row_splits
  for axis in range(outer_axis, inner_axis):
    if isinstance(new_values, RaggedTensor):
      # Flatten a single ragged dimension.
      # Compose the partitions: the merged row_splits index into the inner
      # values, so gather the inner row_splits at the outer split offsets.
      new_splits = array_ops.gather(new_values.row_splits, new_splits)
      new_values = new_values.values
    else:
      # Flatten all remaining dense dimensions.
      shape_split = inner_axis - axis + 1
      if new_values.shape.is_fully_defined():
        old_shape = new_values.shape.as_list()
        new_shape = [-1] + old_shape[shape_split:]
        flat_size = _prod(old_shape[1:shape_split])
      else:
        old_shape = array_ops.shape(new_values)
        new_shape = array_ops.concat([[-1], old_shape[shape_split:]], axis=0)
        flat_size = math_ops.cast(
            math_ops.reduce_prod(old_shape[1:shape_split]), new_splits.dtype)
      new_values = array_ops.reshape(new_values, new_shape)
      # Scale splits by the number of values merged into each row.  The
      # reshape above handled all remaining dense dimensions, so stop here.
      new_splits = new_splits * flat_size
      break
  return RaggedTensor.from_row_splits(new_values, new_splits)
2877
2878
2879def _prod(lst):
2880  """Returns the product of the numbers in a list."""
2881  return functools.reduce(operator.mul, lst, 1)
2882
2883
2884def _get_row_partition_type_tensor_pairs_tail(partition):
2885  """Gets a row partition type tensor pair for the tail.
2886
2887  If value_rowid is defined, then it is used. Otherwise, row_splits
2888  are used.
2889
2890  Args:
2891    partition: a RowPartition.
2892
2893  Returns:
2894    A list of (row_partition_type, row_partition_tensor) pairs.
2895  """
2896  if partition.has_precomputed_value_rowids():
2897    return ("VALUE_ROWIDS", partition.value_rowids())
2898  else:
2899    return ("ROW_SPLITS", partition.row_splits())
2900
2901
def _get_row_partition_type_tensor_pairs(rt_input):
  """Gets a list of the row partitions for rt_input.

  If value_rowids are defined, then they are used. Otherwise, row_splits
  are used. If the outermost level has value_rowids defined, then nrows is
  also added.

  Args:
    rt_input: a ragged tensor.

  Returns:
    A list of (row_partition_type, row_partition_tensor) pairs.
  """
  partitions = rt_input._nested_row_partitions  # pylint: disable=protected-access
  # Inner partitions never need FIRST_DIM_SIZE, so use the tail helper.
  tail = [_get_row_partition_type_tensor_pairs_tail(x) for x in partitions[1:]]

  if partitions[0]._value_rowids is not None:  # pylint: disable=protected-access
    # NOTE(review): FIRST_DIM_SIZE is emitted alongside VALUE_ROWIDS for the
    # outermost level — presumably because value_rowids alone does not
    # determine the number of rows; confirm against the variant-op kernels.
    return [("FIRST_DIM_SIZE", partitions[0].nrows()),
            ("VALUE_ROWIDS", partitions[0].value_rowids())] + tail
  else:
    return [("ROW_SPLITS", partitions[0].row_splits())] + tail
2923
2924
def _shape_as_tensor(shape, dtype):
  """Takes shape and coerces it to a shape as a tensor.

  If the object is already a tensor, simply passes it on (result is guaranteed
  to be int64 or int32, but not necessarily dtype).
  If not, creates a tensor of type dtype.

  Result is either a scalar equal to -1 if the shape is unknown_rank.
  Otherwise, it is a vector, where unknown dimensions are represented with a
  value of -1.

  In C++, see TensorShapeFromTensor for parsing shapes in kernels, and
  InferenceContext::MakeShapeFromShapeTensorTreatScalarAsUnknownShape, for
  use in the shape inference function.

  Args:
    shape: input to coerce from TensorShape, Tensor, None, List[Optional[Int]],
      Tuple[Optional[Int]].
    dtype: tf.int64 or tf.int32

  Returns:
    a scalar or vector tensor of dtype tf.int32 or tf.int64.
  """
  if dtype not in (dtypes.int64, dtypes.int32):
    raise ValueError(f"Expected int64 or int32 for dtype: got {dtype}.")

  if isinstance(shape, ops.Tensor):
    # Tensors that are already int32/int64 pass through unchanged; anything
    # else is cast to the requested dtype.
    if shape.dtype in (dtypes.int64, dtypes.int32):
      return shape
    return math_ops.cast(shape, dtype)
  shape = tensor_shape.as_shape(shape)
  if not shape:
    # Unknown rank is encoded as the scalar -1.
    return constant_op.constant(-1, dtype=dtype)
  # Known rank: encode each unknown dimension as -1.
  dims = [-1 if dim is None else dim for dim in shape.as_list()]
  return constant_op.constant(dims, dtype=dtype)
2962
2963
def _nvals_uniform_row_length(values, uniform_row_length):
  """Get the number of values for uniform row length constructor."""
  static_nvals = tensor_shape.dimension_at_index(values.shape, 0).value
  if static_nvals is not None:
    # Outer dimension is statically known: embed it as a constant.
    return constant_op.constant(static_nvals, uniform_row_length.dtype)
  if isinstance(values, RaggedTensor):
    return values.nrows(out_type=uniform_row_length.dtype)
  # Dense values: read the outer dimension from the dynamic shape.
  return array_ops.shape(values, out_type=uniform_row_length.dtype)[0]
2974
2975
def _get_optional_partition_dtype(values):
  """Returns the partition dtype, or None if None exists."""
  if not isinstance(values, RaggedTensor):
    return None
  return values._row_partition.dtype  # pylint: disable=protected-access
2982
2983
# Tuple of types that may be used as the `values` of a RaggedTensor; grown at
# runtime by `_add_supported_value_type` below.
_SUPPORTED_RAGGED_VALUE_TYPES = (ops.Tensor, RaggedTensor)
2985
2986
# TODO(edloper): Consider whether we should change the registry to be on
# TypeSpecs rather than ValueTypes.
def _add_supported_value_type(cls):
  """Register the `cls` as supported value type of RaggedTensor.

  The cls must be a subclass of CompositeTensor, and must support:
   - Properties:
     - x.shape
     - x.dtype
   - Methods:
     - x.__getitem__(idx) (method: returns a supported value type)
   - Ops:
     - tf.shape(x) -- tf.shape(x)[0] must be a tf.Tensor.
     - tf.tile(x)
     - assert_rank_at_least(x)
     - tf.ones_like(x)
     - tf.gather(params=x, indices=Tensor)
     - tf.add(x, y)
     - tf.boolean_mask(x, ...)
     - @TODO(edloper): Complete this list

   Note: the following RaggedTensor, RaggedTensorSpec methods & ops are not
   currently supported unless `rt.values` is a RaggedTensor or a tf.Tensor:
     - rt.to_tensor()
     - rt.to_sparse_tensor()
     - rt._to_variant()
     - rt._from_variant()
     - tf.ragged.cross([rt])
     - tf.gather(params=x, indices=rt)  # rt used for indices
     - RaggedTensorSpec methods:
       - _batch
       - _unbatch
       - _to_tensor_list
       - _to_batched_tensor_list
       - _from_compatible_tensor_list

  Args:
    cls: The type to be added to supported value types.
  """
  if not issubclass(cls, composite_tensor.CompositeTensor):
    raise ValueError(f"cls ({cls}) must be a subclass of CompositeTensor.")
  # Both properties are needed so RaggedTensor can delegate shape/dtype
  # queries to its values.
  for required in ("shape", "dtype"):
    if not hasattr(cls, required):
      raise ValueError(f"cls must support the `{required}` property.")
  global _SUPPORTED_RAGGED_VALUE_TYPES
  _SUPPORTED_RAGGED_VALUE_TYPES = _SUPPORTED_RAGGED_VALUE_TYPES + (cls,)
3034
3035
def _is_supported_ragged_values_type(value):
  """Returns True if `value` may be used as the `values` of a RaggedTensor."""
  return isinstance(value, _SUPPORTED_RAGGED_VALUE_TYPES)
3038
3039
def _assert_is_supported_ragged_values_type(value):
  """Raises TypeError if `value` is not a supported ragged values type."""
  if _is_supported_ragged_values_type(value):
    return
  ok_types = ", ".join(cls.__name__ for cls in _SUPPORTED_RAGGED_VALUE_TYPES)
  raise TypeError(f"type(values) must be one of: {ok_types}, got {value}.")
3044