• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Tests for third_party.tensorflow.python.ops.ragged_tensor."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22from absl.testing import parameterized
23import numpy as np
24
25from tensorflow.python.data.ops import dataset_ops
26from tensorflow.python.eager import backprop
27from tensorflow.python.eager import context
28from tensorflow.python.framework import constant_op
29from tensorflow.python.framework import dtypes
30from tensorflow.python.framework import errors
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_shape
33from tensorflow.python.framework import tensor_spec
34from tensorflow.python.framework import test_util
35from tensorflow.python.ops import array_grad  # pylint: disable=unused-import
36from tensorflow.python.ops import array_ops
37from tensorflow.python.ops import control_flow_ops
38from tensorflow.python.ops import gen_ragged_conversion_ops
39from tensorflow.python.ops import gradients_impl
40from tensorflow.python.ops import map_fn
41from tensorflow.python.ops import math_grad  # pylint: disable=unused-import
42from tensorflow.python.ops import math_ops
43from tensorflow.python.ops import tensor_array_grad  # pylint: disable=unused-import
44from tensorflow.python.ops.ragged import ragged_factory_ops
45from tensorflow.python.ops.ragged import ragged_math_ops
46from tensorflow.python.ops.ragged import ragged_tensor
47from tensorflow.python.ops.ragged import ragged_tensor_value
48from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
49from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensorSpec
50from tensorflow.python.ops.ragged.row_partition import RowPartition
51
52from tensorflow.python.platform import googletest
53from tensorflow.python.util import nest
54
55
def int32array(values):
  """Converts `values` into a NumPy array of dtype int32.

  Args:
    values: Anything `np.array` accepts, e.g. a (nested) list of ints.

  Returns:
    A newly constructed `np.ndarray` whose dtype is `np.int32`.
  """
  return np.array(values, np.int32)
58
59
60@test_util.run_all_in_graph_and_eager_modes
61class RaggedTensorTest(test_util.TensorFlowTestCase, parameterized.TestCase):
62  longMessage = True  # Property in unittest.Testcase. pylint: disable=invalid-name
63
  #=============================================================================
  # RaggedTensor class docstring examples
  #=============================================================================
67
68  def testClassDocStringExamples(self):
69    # From section: "Component Tensors"
70    rt = RaggedTensor.from_row_splits(
71        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
72    self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
73    del rt
74
75    # From section: "Alternative Row-Partitioning Schemes"
76    values = [3, 1, 4, 1, 5, 9, 2, 6]
77    rt1 = RaggedTensor.from_row_splits(values, row_splits=[0, 4, 4, 7, 8, 8])
78    rt2 = RaggedTensor.from_row_lengths(values, row_lengths=[4, 0, 3, 1, 0])
79    rt3 = RaggedTensor.from_value_rowids(
80        values, value_rowids=[0, 0, 0, 0, 2, 2, 2, 3], nrows=5)
81    rt4 = RaggedTensor.from_row_starts(values, row_starts=[0, 4, 4, 7, 8])
82    rt5 = RaggedTensor.from_row_limits(values, row_limits=[4, 4, 7, 8, 8])
83    for rt in (rt1, rt2, rt3, rt4, rt5):
84      self.assertAllEqual(rt, [[3, 1, 4, 1], [], [5, 9, 2], [6], []])
85    del rt1, rt2, rt3, rt4, rt5
86
87    # From section: "Multiple Ragged Dimensions"
88    inner_rt = RaggedTensor.from_row_splits(
89        values=[3, 1, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 4, 7, 8, 8])
90    outer_rt = RaggedTensor.from_row_splits(
91        values=inner_rt, row_splits=[0, 3, 3, 5])
92    self.assertEqual(outer_rt.ragged_rank, 2)
93    self.assertAllEqual(outer_rt,
94                        [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
95    del inner_rt, outer_rt
96
97    # From section: "Multiple Ragged Dimensions"
98    rt = RaggedTensor.from_nested_row_splits(
99        flat_values=[3, 1, 4, 1, 5, 9, 2, 6],
100        nested_row_splits=([0, 3, 3, 5], [0, 4, 4, 7, 8, 8]))
101    self.assertAllEqual(rt, [[[3, 1, 4, 1], [], [5, 9, 2]], [], [[6], []]])
102    del rt
103
104    # From section: "Uniform Inner Dimensions"
105    rt = RaggedTensor.from_row_splits(
106        values=array_ops.ones([5, 3]), row_splits=[0, 2, 5])
107    self.assertAllEqual(
108        rt, [[[1, 1, 1], [1, 1, 1]], [[1, 1, 1], [1, 1, 1], [1, 1, 1]]])
109    self.assertEqual(rt.shape.as_list(), [2, None, 3])
110    del rt
111
  #=============================================================================
  # RaggedTensorValue Constructor
  #=============================================================================
115
116  def testRaggedTensorValueConstruction(self):
117    values = np.array(b'a b c d e f g'.split())
118    splits = np.array([0, 2, 5, 6, 6, 7], dtype=np.int64)
119    splits2 = np.array([0, 3, 5], dtype=np.int64)
120
121    # Test construction of a RaggedTensorValue with ragged_rank=1.
122    rt_value = ragged_tensor_value.RaggedTensorValue(values, splits)
123    self.assertEqual(rt_value.row_splits.dtype, np.int64)
124    self.assertEqual(rt_value.shape, (5, None))
125    self.assertLen(rt_value.nested_row_splits, 1)
126    self.assertAllEqual(splits, rt_value.row_splits)
127    self.assertAllEqual(values, rt_value.values)
128    self.assertAllEqual(splits, rt_value.nested_row_splits[0])
129    self.assertAllEqual(values, rt_value.flat_values)
130
131    # Test construction of a RaggedTensorValue with ragged_rank=2.
132    rt_value = ragged_tensor_value.RaggedTensorValue(
133        values=ragged_tensor_value.RaggedTensorValue(values, splits),
134        row_splits=splits2)
135    self.assertEqual(rt_value.row_splits.dtype, np.int64)
136    self.assertEqual(rt_value.shape, (2, None, None))
137    self.assertLen(rt_value.nested_row_splits, 2)
138    self.assertAllEqual(splits2, rt_value.row_splits)
139    self.assertAllEqual(splits, rt_value.values.row_splits)
140    self.assertAllEqual(splits2, rt_value.nested_row_splits[0])
141    self.assertAllEqual(splits, rt_value.nested_row_splits[1])
142    self.assertAllEqual(values, rt_value.values.values)
143    self.assertAllEqual(values, rt_value.flat_values)
144
  #=============================================================================
  # RaggedTensor Constructor (private)
  #=============================================================================
148
149  def testRaggedTensorConstruction(self):
150    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
151    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
152    rp = RowPartition.from_row_splits(row_splits)
153    rt = RaggedTensor(values=values, row_partition=rp, internal=True)
154
155    self.assertAllEqual(rt,
156                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
157
158  def testRaggedTensorConstructionErrors(self):
159    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
160    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
161    rp = RowPartition.from_row_splits(row_splits)
162
163    with self.assertRaisesRegex(ValueError,
164                                'RaggedTensor constructor is private'):
165      RaggedTensor(values=values, row_partition=rp)
166
167    with self.assertRaisesRegex(
168        TypeError,
169        r"""type\(values\) must be one of: 'Tensor, RaggedTensor.*"""):
170      RaggedTensor(values=range(7), row_partition=rp, internal=True)
171
172    with self.assertRaisesRegex(TypeError,
173                                'row_partition must be a RowPartition'):
174      RaggedTensor(
175          values=values, row_partition=[0, 2, 2, 5, 6, 7], internal=True)
176
  #=============================================================================
  # RaggedTensor Factory Ops
  #=============================================================================
180
181  def testFromValueRowIdsWithDerivedNRows(self):
182    # nrows is known at graph creation time.
183    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
184    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
185
186    rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False)
187    self.assertEqual(rt.dtype, dtypes.string)
188    self.assertEqual(rt.shape.as_list(), [5, None])
189    self.assertEqual(rt.ragged_rank, 1)
190
191    rt_values = rt.values
192    rt_value_rowids = rt.value_rowids()
193    rt_nrows = rt.nrows()
194
195    self.assertIs(rt_values, values)
196    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
197    self.assertAllEqual(rt_value_rowids, value_rowids)
198    self.assertAllEqual(rt_nrows, 5)
199    self.assertAllEqual(rt,
200                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
201
202  def testFromValueRowIdsWithDerivedNRowsDynamic(self):
203    # nrows is not known at graph creation time.
204    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
205    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
206    value_rowids = array_ops.placeholder_with_default(value_rowids, shape=None)
207
208    rt = RaggedTensor.from_value_rowids(values, value_rowids, validate=False)
209    self.assertEqual(rt.dtype, dtypes.string)
210    if context.executing_eagerly():
211      self.assertEqual(rt.shape.as_list(), [5, None])
212    else:
213      self.assertEqual(rt.shape.as_list(), [None, None])
214    self.assertEqual(rt.ragged_rank, 1)
215
216    rt_values = rt.values
217    rt_value_rowids = rt.value_rowids()
218    rt_nrows = rt.nrows()
219
220    self.assertIs(rt_values, values)
221    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
222    self.assertAllEqual(rt_value_rowids, value_rowids)
223    self.assertAllEqual(rt_nrows, 5)
224    self.assertAllEqual(rt,
225                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
226
227  def testFromValueRowIdsWithExplicitNRows(self):
228    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
229    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
230    nrows = constant_op.constant(7, dtypes.int64)
231
232    rt = RaggedTensor.from_value_rowids(
233        values, value_rowids, nrows, validate=False)
234    self.assertEqual(rt.dtype, dtypes.string)
235    self.assertEqual(rt.shape.as_list(), [7, None])
236    self.assertEqual(rt.ragged_rank, 1)
237
238    rt_values = rt.values
239    rt_value_rowids = rt.value_rowids()
240    rt_nrows = rt.nrows()
241
242    self.assertIs(rt_values, values)
243    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
244    self.assertIs(rt_nrows, nrows)  # cached_nrows
245    self.assertAllEqual(
246        rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g'], [], []])
247
248  def testFromValueRowIdsWithExplicitNRowsEqualToDefault(self):
249    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
250    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
251    nrows = constant_op.constant(5, dtypes.int64)
252
253    rt = RaggedTensor.from_value_rowids(
254        values, value_rowids, nrows, validate=False)
255    self.assertEqual(rt.dtype, dtypes.string)
256    self.assertEqual(rt.shape.as_list(), [5, None])
257    self.assertEqual(rt.ragged_rank, 1)
258
259    rt_values = rt.values
260    rt_value_rowids = rt.value_rowids()
261    rt_nrows = rt.nrows()
262
263    self.assertIs(rt_values, values)
264    self.assertIs(rt_value_rowids, value_rowids)  # cached_value_rowids
265    self.assertIs(rt_nrows, nrows)  # cached_nrows
266    self.assertAllEqual(rt_value_rowids, value_rowids)
267    self.assertAllEqual(rt_nrows, nrows)
268    self.assertAllEqual(rt,
269                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
270
271  def testFromValueRowIdsWithEmptyValues(self):
272    rt = RaggedTensor.from_value_rowids([], [])
273    rt_nrows = rt.nrows()
274    self.assertEqual(rt.dtype, dtypes.float32)
275    self.assertEqual(rt.shape.as_list(), [0, None])
276    self.assertEqual(rt.ragged_rank, 1)
277    self.assertEqual(rt.values.shape.as_list(), [0])
278    self.assertEqual(rt.value_rowids().shape.as_list(), [0])
279    self.assertAllEqual(rt_nrows, 0)
280    self.assertAllEqual(rt, [])
281
282  def testFromRowSplits(self):
283    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
284    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
285
286    rt = RaggedTensor.from_row_splits(values, row_splits, validate=False)
287    self.assertEqual(rt.dtype, dtypes.string)
288    self.assertEqual(rt.shape.as_list(), [5, None])
289    self.assertEqual(rt.ragged_rank, 1)
290
291    rt_values = rt.values
292    rt_row_splits = rt.row_splits
293    rt_nrows = rt.nrows()
294
295    self.assertIs(rt_values, values)
296    self.assertIs(rt_row_splits, row_splits)
297    self.assertAllEqual(rt_nrows, 5)
298    self.assertAllEqual(rt,
299                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
300
301  def testFromRowSplitsWithDifferentSplitTypes(self):
302    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
303    splits1 = [0, 2, 2, 5, 6, 7]
304    splits2 = np.array([0, 2, 2, 5, 6, 7], np.int64)
305    splits3 = np.array([0, 2, 2, 5, 6, 7], np.int32)
306    splits4 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
307    splits5 = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int32)
308    rt1 = RaggedTensor.from_row_splits(values, splits1)
309    rt2 = RaggedTensor.from_row_splits(values, splits2)
310    rt3 = RaggedTensor.from_row_splits(values, splits3)
311    rt4 = RaggedTensor.from_row_splits(values, splits4)
312    rt5 = RaggedTensor.from_row_splits(values, splits5)
313    self.assertEqual(rt1.row_splits.dtype, dtypes.int64)
314    self.assertEqual(rt2.row_splits.dtype, dtypes.int64)
315    self.assertEqual(rt3.row_splits.dtype, dtypes.int32)
316    self.assertEqual(rt4.row_splits.dtype, dtypes.int64)
317    self.assertEqual(rt5.row_splits.dtype, dtypes.int32)
318
319  def testFromRowSplitsWithEmptySplits(self):
320    err_msg = 'row_splits tensor may not be empty'
321    with self.assertRaisesRegex(ValueError, err_msg):
322      RaggedTensor.from_row_splits([], [])
323
324  def testFromRowStarts(self):
325    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
326    row_starts = constant_op.constant([0, 2, 2, 5, 6], dtypes.int64)
327
328    rt = RaggedTensor.from_row_starts(values, row_starts, validate=False)
329    self.assertEqual(rt.dtype, dtypes.string)
330    self.assertEqual(rt.shape.as_list(), [5, None])
331    self.assertEqual(rt.ragged_rank, 1)
332
333    rt_values = rt.values
334    rt_row_starts = rt.row_starts()
335    rt_nrows = rt.nrows()
336
337    self.assertIs(rt_values, values)
338    self.assertAllEqual(rt_nrows, 5)
339    self.assertAllEqual(rt_row_starts, row_starts)
340    self.assertAllEqual(rt,
341                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
342
343  def testFromRowLimits(self):
344    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
345    row_limits = constant_op.constant([2, 2, 5, 6, 7], dtypes.int64)
346
347    rt = RaggedTensor.from_row_limits(values, row_limits, validate=False)
348    self.assertEqual(rt.dtype, dtypes.string)
349    self.assertEqual(rt.shape.as_list(), [5, None])
350    self.assertEqual(rt.ragged_rank, 1)
351
352    rt_values = rt.values
353    rt_row_limits = rt.row_limits()
354    rt_nrows = rt.nrows()
355
356    self.assertIs(rt_values, values)
357    self.assertAllEqual(rt_nrows, 5)
358    self.assertAllEqual(rt_row_limits, row_limits)
359    self.assertAllEqual(rt,
360                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
361
362  def testFromRowLengths(self):
363    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
364    row_lengths = constant_op.constant([2, 0, 3, 1, 1], dtypes.int64)
365
366    rt = RaggedTensor.from_row_lengths(values, row_lengths, validate=False)
367    self.assertEqual(rt.dtype, dtypes.string)
368    self.assertEqual(rt.shape.as_list(), [5, None])
369    self.assertEqual(rt.ragged_rank, 1)
370
371    rt_values = rt.values
372    rt_row_lengths = rt.row_lengths()
373    rt_nrows = rt.nrows()
374
375    self.assertIs(rt_values, values)
376    self.assertIs(rt_row_lengths, row_lengths)  # cached_nrows
377    self.assertAllEqual(rt_nrows, 5)
378    self.assertAllEqual(rt_row_lengths, row_lengths)
379    self.assertAllEqual(rt,
380                        [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
381
382  def testFromRowLengthsInt32(self):
383    rt = RaggedTensor.from_row_lengths([1, 2, 3, 4],
384                                       constant_op.constant([1, 0, 3],
385                                                            dtype=dtypes.int32))
386    rt2 = RaggedTensor.from_row_lengths(rt, [2, 1, 0])
387    self.assertAllEqual([2, 1, 0], rt2.row_lengths())
388
389  def testFromUniformRowLength(self):
390    values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
391
392    a1 = RaggedTensor.from_uniform_row_length(values, 2)
393    a2 = RaggedTensor.from_uniform_row_length(values, 2, 8)
394    self.assertAllEqual(
395        a1,
396        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14], [15, 16]])
397    self.assertAllEqual(a1, a2)
398    self.assertEqual(a1.shape.as_list(), [8, 2])
399    self.assertEqual(a2.shape.as_list(), [8, 2])
400
401    b1 = RaggedTensor.from_uniform_row_length(a1, 2)
402    b2 = RaggedTensor.from_uniform_row_length(a1, 2, 4)
403    self.assertAllEqual(b1, [[[1, 2], [3, 4]], [[5, 6], [7, 8]],
404                             [[9, 10], [11, 12]], [[13, 14], [15, 16]]])
405    self.assertAllEqual(b1, b2)
406    self.assertEqual(b1.shape.as_list(), [4, 2, 2])
407    self.assertEqual(b2.shape.as_list(), [4, 2, 2])
408
409    c1 = RaggedTensor.from_uniform_row_length(b1, 2)
410    c2 = RaggedTensor.from_uniform_row_length(b1, 2, 2)
411    self.assertAllEqual(c1, [[[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
412                             [[[9, 10], [11, 12]], [[13, 14], [15, 16]]]])
413    self.assertAllEqual(c1, c2)
414    self.assertEqual(c1.shape.as_list(), [2, 2, 2, 2])
415    self.assertEqual(c2.shape.as_list(), [2, 2, 2, 2])
416
417  def testFromUniformRowLengthWithEmptyValues(self):
418    empty_values = []
419    a = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=10)
420    self.assertEqual(a.shape.as_list(), [10, 0])
421
422    b = RaggedTensor.from_uniform_row_length(a, 2)
423    self.assertEqual(b.shape.as_list(), [5, 2, 0])
424
425    # Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0.
426    c = RaggedTensor.from_uniform_row_length(empty_values, 0)
427    self.assertEqual(c.shape.as_list(), [0, 0])
428    d = RaggedTensor.from_uniform_row_length(empty_values, 0, nrows=0)
429    self.assertEqual(d.shape.as_list(), [0, 0])
430
431  def testFromUniformRowLengthWithPlaceholders(self):
432    ph_values = array_ops.placeholder_with_default([1, 2, 3, 4, 5, 6], [None])
433    ph_rowlen = array_ops.placeholder_with_default(3, None)
434    rt1 = RaggedTensor.from_uniform_row_length(ph_values, 3)
435    rt2 = RaggedTensor.from_uniform_row_length(ph_values, ph_rowlen)
436    rt3 = RaggedTensor.from_uniform_row_length([1, 2, 3, 4, 5, 6], ph_rowlen)
437    self.assertAllEqual(rt1, [[1, 2, 3], [4, 5, 6]])
438    self.assertAllEqual(rt2, [[1, 2, 3], [4, 5, 6]])
439    self.assertAllEqual(rt3, [[1, 2, 3], [4, 5, 6]])
440    if context.executing_eagerly():
441      self.assertEqual(rt1.shape.as_list(), [2, 3])
442      self.assertEqual(rt2.shape.as_list(), [2, 3])
443      self.assertEqual(rt3.shape.as_list(), [2, 3])
444    else:
445      self.assertEqual(rt1.shape.as_list(), [None, 3])
446      self.assertEqual(rt2.shape.as_list(), [None, None])
447      self.assertEqual(rt3.shape.as_list(), [None, None])
448
449    b = RaggedTensor.from_uniform_row_length(rt1, 2)
450    self.assertAllEqual(b, [[[1, 2, 3], [4, 5, 6]]])
451
452    # Make sure we avoid divide-by-zero when finding nrows for nvals=rowlen=0.
453    ph_empty_values = array_ops.placeholder_with_default(
454        array_ops.zeros([0], dtypes.int64), [None])
455    ph_zero = array_ops.placeholder_with_default(0, [])
456    c = RaggedTensor.from_uniform_row_length(ph_empty_values, ph_zero)
457    if context.executing_eagerly():
458      self.assertEqual(c.shape.as_list(), [0, 0])
459    else:
460      self.assertEqual(c.shape.as_list(), [None, None])
461
462  def testFromNestedValueRowIdsWithDerivedNRows(self):
463    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
464    nested_value_rowids = [
465        constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
466        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
467    ]
468
469    rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
470    self.assertEqual(rt.dtype, dtypes.string)
471    self.assertEqual(rt.shape.as_list(), [4, None, None])
472    self.assertEqual(rt.ragged_rank, 2)
473
474    rt_values = rt.values
475    rt_value_rowids = rt.value_rowids()
476    rt_values_values = rt_values.values
477    rt_values_value_rowids = rt_values.value_rowids()
478
479    self.assertIs(rt_values_values, values)
480    self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
481    self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
482    self.assertAllEqual(
483        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
484
485  def testFromNestedValueRowIdsWithExplicitNRows(self):
486    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
487    nested_value_rowids = [
488        constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64),
489        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
490    ]
491    nrows = [
492        constant_op.constant(6, dtypes.int64),
493        constant_op.constant(6, dtypes.int64)
494    ]
495
496    rt = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids,
497                                               nrows)
498    self.assertEqual(rt.dtype, dtypes.string)
499    self.assertEqual(rt.shape.as_list(), [6, None, None])
500    self.assertEqual(rt.ragged_rank, 2)
501
502    rt_values = rt.values
503    rt_value_rowids = rt.value_rowids()
504    rt_nrows = rt.nrows()
505    rt_values_values = rt_values.values
506    rt_values_value_rowids = rt_values.value_rowids()
507    rt_values_nrows = rt_values.nrows()
508
509    self.assertIs(rt_values_values, values)
510    self.assertAllEqual(rt_value_rowids, nested_value_rowids[0])
511    self.assertAllEqual(rt_values_value_rowids, nested_value_rowids[1])
512    self.assertAllEqual(rt_nrows, nrows[0])
513    self.assertAllEqual(rt_values_nrows, nrows[1])
514    self.assertAllEqual(rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [],
515                             [[b'f'], [b'g'], []], [], []])
516
517  def testFromNestedValueRowIdsWithExplicitNRowsMismatch(self):
518    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
519    nested_value_rowids = [
520        constant_op.constant([0, 0, 1, 3, 3, 3], dtypes.int64),
521        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
522    ]
523    nrows = [constant_op.constant(6, dtypes.int64)]
524    with self.assertRaisesRegex(
525        ValueError, 'nested_nrows must have the same '
526        'length as nested_value_rowids'):
527      RaggedTensor.from_nested_value_rowids(values, nested_value_rowids, nrows)
528
529  def testFromNestedValueRowIdsWithNonListInput(self):
530    with self.assertRaisesRegex(
531        TypeError, 'nested_value_rowids must be a list of Tensors'):
532      RaggedTensor.from_nested_value_rowids(
533          [1, 2, 3], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
534    with self.assertRaisesRegex(TypeError,
535                                'nested_nrows must be a list of Tensors'):
536      RaggedTensor.from_nested_value_rowids([1, 2, 3], [[0, 1, 2], [0, 1, 2]],
537                                            constant_op.constant([3, 3]))
538
539  def testFromNestedRowSplits(self):
540    flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
541    nested_row_splits = [
542        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
543        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
544    ]
545
546    rt = RaggedTensor.from_nested_row_splits(
547        flat_values, nested_row_splits, validate=False)
548    self.assertEqual(rt.dtype, dtypes.string)
549    self.assertEqual(rt.shape.as_list(), [4, None, None])
550    self.assertEqual(rt.ragged_rank, 2)
551
552    rt_values = rt.values
553    rt_row_splits = rt.row_splits
554    rt_values_values = rt_values.values
555    rt_values_row_splits = rt_values.row_splits
556
557    self.assertIs(rt_values_values, flat_values)
558    self.assertIs(rt_row_splits, nested_row_splits[0])
559    self.assertIs(rt_values_row_splits, nested_row_splits[1])
560    self.assertAllEqual(
561        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
562
563  def testWithRowSplits(self):
564    flat_values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
565    nested_row_splits = [
566        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
567        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
568    ]
569
570    rt = RaggedTensor.from_nested_row_splits(
571        flat_values, nested_row_splits, validate=False)
572
573    rt = rt.with_row_splits_dtype(dtypes.int32)
574
575    self.assertEqual(rt.dtype, dtypes.string)
576    self.assertEqual(rt.shape.as_list(), [4, None, None])
577    self.assertEqual(rt.ragged_rank, 2)
578
579    rt_values = rt.values
580    rt_row_splits = rt.row_splits
581    rt_values_values = rt_values.values
582    rt_values_row_splits = rt_values.row_splits
583
584    self.assertAllEqual(rt_values_values, flat_values)
585    self.assertAllEqual(rt_row_splits, nested_row_splits[0])
586    self.assertAllEqual(rt_values_row_splits, nested_row_splits[1])
587    self.assertAllEqual(
588        rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
589
590  def testFromNestedRowSplitsWithNonListInput(self):
591    with self.assertRaisesRegex(TypeError,
592                                'nested_row_splits must be a list of Tensors'):
593      RaggedTensor.from_nested_row_splits(
594          [1, 2], constant_op.constant([[0, 1, 2], [0, 1, 2]], dtypes.int64))
595
596  def testFromValueRowIdsWithBadNRows(self):
597    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
598    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
599    nrows = constant_op.constant(5, dtypes.int64)
600
601    with self.assertRaisesRegex(ValueError, r'Expected nrows >= 0; got -2'):
602      RaggedTensor.from_value_rowids(
603          values=values,
604          value_rowids=array_ops.placeholder_with_default(value_rowids, None),
605          nrows=-2)
606
607    with self.assertRaisesRegex(
608        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=2, '
609        r'value_rowids\[-1\]=4'):
610      RaggedTensor.from_value_rowids(
611          values=values, value_rowids=value_rowids, nrows=2)
612
613    with self.assertRaisesRegex(
614        ValueError, r'Expected nrows >= value_rowids\[-1\] \+ 1; got nrows=4, '
615        r'value_rowids\[-1\]=4'):
616      RaggedTensor.from_value_rowids(
617          values=values, value_rowids=value_rowids, nrows=4)
618
619    with self.assertRaisesRegex(ValueError, r'Shape \(7, 1\) must have rank 1'):
620      RaggedTensor.from_value_rowids(
621          values=values,
622          value_rowids=array_ops.expand_dims(value_rowids, 1),
623          nrows=nrows)
624
625    with self.assertRaisesRegex(ValueError, r'Shape \(1,\) must have rank 0'):
626      RaggedTensor.from_value_rowids(
627          values=values,
628          value_rowids=value_rowids,
629          nrows=array_ops.expand_dims(nrows, 0))
630
631  def testCondWithTensorsFromValueIds(self):
632    # b/141166460
633    rt = RaggedTensor.from_value_rowids([1, 2, 3], [0, 0, 2])
634    c = array_ops.placeholder_with_default(True, None)
635    result = control_flow_ops.cond(c, lambda: rt, lambda: rt)
636    self.assertAllEqual(rt, result)
637
638  def testGraphMismatch(self):
639    if not context.executing_eagerly():
640      with ops.Graph().as_default():
641        values = constant_op.constant([1, 2, 3], dtypes.int64)
642      with ops.Graph().as_default():
643        splits = constant_op.constant([0, 2, 3], dtypes.int64)
644      with self.assertRaisesRegex(ValueError,
645                                  '.* must be from the same graph as .*'):
646        RaggedTensor.from_row_splits(values, splits)
647
  #=============================================================================
  # Ragged Value & Row-Partitioning Tensor Accessors
  #=============================================================================
651
652  def testRaggedTensorAccessors_2d(self):
653    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
654    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
655    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
656    rt1 = RaggedTensor.from_row_splits(values, row_splits)
657    rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
658
659    for rt in [rt1, rt2]:
660      self.assertAllEqual(
661          rt, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
662      self.assertAllEqual(rt.values, [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
663      self.assertEqual(rt.values.shape.dims[0].value, 7)
664      self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
665      self.assertAllEqual(rt.nrows(), 5)
666      self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
667      self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
668      self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
669      self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
670      self.assertAllEqual(rt.flat_values,
671                          [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
672      self.assertLen(rt.nested_row_splits, 1)
673      self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
674
675  def testRaggedTensorAccessors_3d_with_ragged_rank_1(self):
676    values = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]]
677    row_splits = constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
678    value_rowids = constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
679    row_lengths = constant_op.constant([2, 0, 3, 1, 1])
680    rt1 = RaggedTensor.from_row_splits(values, row_splits)
681    rt2 = RaggedTensor.from_value_rowids(values, value_rowids)
682    rt3 = RaggedTensor.from_row_lengths(values, row_lengths)
683
684    for rt in [rt1, rt2, rt3]:
685      self.assertAllEqual(rt, [[[0, 1], [2, 3]], [], [[4, 5], [6, 7], [8, 9]],
686                               [[10, 11]], [[12, 13]]])
687      self.assertAllEqual(
688          rt.values,
689          [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
690      self.assertEqual(rt.values.shape.dims[0].value, 7)
691      self.assertAllEqual(rt.value_rowids(), [0, 0, 2, 2, 2, 3, 4])
692      self.assertAllEqual(rt.nrows(), 5)
693      self.assertAllEqual(rt.row_splits, [0, 2, 2, 5, 6, 7])
694      self.assertAllEqual(rt.row_starts(), [0, 2, 2, 5, 6])
695      self.assertAllEqual(rt.row_limits(), [2, 2, 5, 6, 7])
696      self.assertAllEqual(rt.row_lengths(), [2, 0, 3, 1, 1])
697      self.assertAllEqual(rt.row_lengths(axis=2),
698                          [[2, 2], [], [2, 2, 2], [2], [2]])
699      self.assertAllEqual(
700          rt.flat_values,
701          [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11], [12, 13]])
702      self.assertLen(rt.nested_row_splits, 1)
703      self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 2, 5, 6, 7])
704      self.assertLen(rt.nested_value_rowids(), 1)
705
706      self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 2, 2, 2, 3, 4])
707
708  def testRaggedTensorAccessors_3d_with_ragged_rank_2(self):
709    values = constant_op.constant(['a', 'b', 'c', 'd', 'e', 'f', 'g'])
710    nested_row_splits = [
711        constant_op.constant([0, 2, 3, 3, 5], dtypes.int64),
712        constant_op.constant([0, 2, 2, 5, 6, 7], dtypes.int64)
713    ]
714    nested_value_rowids = [
715        constant_op.constant([0, 0, 1, 3, 3], dtypes.int64),
716        constant_op.constant([0, 0, 2, 2, 2, 3, 4], dtypes.int64)
717    ]
718    rt1 = RaggedTensor.from_nested_row_splits(values, nested_row_splits)
719    rt2 = RaggedTensor.from_nested_value_rowids(values, nested_value_rowids)
720
721    for rt in [rt1, rt2]:
722      self.assertAllEqual(
723          rt, [[[b'a', b'b'], []], [[b'c', b'd', b'e']], [], [[b'f'], [b'g']]])
724      self.assertAllEqual(
725          rt.values, [[b'a', b'b'], [], [b'c', b'd', b'e'], [b'f'], [b'g']])
726      self.assertEqual(rt.values.shape.dims[0].value, 5)
727      self.assertAllEqual(rt.value_rowids(), [0, 0, 1, 3, 3])
728      self.assertAllEqual(rt.nrows(), 4)
729      self.assertAllEqual(rt.row_splits, [0, 2, 3, 3, 5])
730      self.assertAllEqual(rt.row_starts(), [0, 2, 3, 3])
731      self.assertAllEqual(rt.row_limits(), [2, 3, 3, 5])
732      self.assertAllEqual(rt.row_lengths(), [2, 1, 0, 2])
733      self.assertAllEqual(rt.flat_values,
734                          [b'a', b'b', b'c', b'd', b'e', b'f', b'g'])
735      self.assertLen(rt.nested_row_splits, 2)
736      self.assertAllEqual(rt.nested_row_splits[0], [0, 2, 3, 3, 5])
737      self.assertAllEqual(rt.nested_row_splits[1], [0, 2, 2, 5, 6, 7])
738      self.assertLen(rt.nested_value_rowids(), 2)
739      self.assertAllEqual(rt.nested_value_rowids()[0], [0, 0, 1, 3, 3])
740      self.assertAllEqual(rt.nested_value_rowids()[1], [0, 0, 2, 2, 2, 3, 4])
741
742  #=============================================================================
743  # RaggedTensor.shape
744  #=============================================================================
745
746  def testShape(self):
747    """Tests for RaggedTensor.shape."""
748    rt1 = RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
749                                       [0, 2, 5, 6, 6, 7])
750    self.assertEqual(rt1.shape.as_list(), [5, None])
751
752    rt2 = RaggedTensor.from_row_splits(
753        [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12], [13, 14]],
754        [0, 2, 5, 6, 6, 7])
755    self.assertEqual(rt2.shape.as_list(), [5, None, 2])
756
757    rt3 = RaggedTensor.from_row_splits(
758        [[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]], [0, 2, 2, 3])
759    self.assertEqual(rt3.shape.as_list(), [3, None, 2, 2])
760
761    rt4 = RaggedTensor.from_row_splits(rt3, [0, 1, 3, 3])
762    self.assertEqual(rt4.shape.as_list(), [3, None, None, 2, 2])
763
764    if not context.executing_eagerly():
765      rt5 = RaggedTensor.from_row_splits(
766          array_ops.placeholder(dtype=dtypes.string), [0, 2, 3, 5])
767      self.assertIsNone(rt5.shape.ndims)
768
769      rt6 = RaggedTensor.from_row_splits(
770          [1, 2, 3], array_ops.placeholder(dtype=dtypes.int64))
771      self.assertEqual(rt6.shape.as_list(), [None, None])
772
773  def testGetShape(self):
774    rt = RaggedTensor.from_row_splits(b'a b c d e f g'.split(),
775                                      [0, 2, 5, 6, 6, 7])
776    self.assertEqual(rt.shape.as_list(), rt.get_shape().as_list())
777
778  #=============================================================================
779  # RaggedTensor.__str__
780  #=============================================================================
781  def testRaggedTensorStr(self):
782    values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
783    row_splits = [0, 2, 5, 6, 6, 7]
784    rt = RaggedTensor.from_row_splits(values, row_splits, validate=False)
785    splits_type = 'int64'
786    if context.executing_eagerly():
787      expected_repr = '<tf.RaggedTensor {}>'.format([[b'a', b'b'],
788                                                     [b'c', b'd', b'e'], [b'f'],
789                                                     [], [b'g']])
790    else:
791      expected_repr = (
792          'tf.RaggedTensor(values=Tensor("RaggedFromRowSplits/values:0", '
793          'shape=(7,), dtype=string), '
794          'row_splits=Tensor('
795          '"RaggedFromRowSplits/RowPartitionFromRowSplits/row_splits:0",'
796          ' shape=(6,), dtype={}))').format(splits_type)
797    self.assertEqual(repr(rt), expected_repr)
798    self.assertEqual(str(rt), expected_repr)
799
800  def testRaggedTensorValueStr(self):
801    values = [b'a', b'b', b'c', b'd', b'e', b'f', b'g']
802    row_splits = [0, 2, 5, 6, 6, 7]
803    rt = ragged_tensor_value.RaggedTensorValue(
804        np.array(values), np.array(row_splits, dtype=np.int64))
805    expected_str = '<tf.RaggedTensorValue {}>'.format([[b'a', b'b'],
806                                                       [b'c', b'd', b'e'],
807                                                       [b'f'], [], [b'g']])
808    expected_repr = ("tf.RaggedTensorValue(values=array({}, dtype='|S1'), "
809                     'row_splits=array({}))'.format(values, row_splits))
810    self.assertEqual(' '.join(str(rt).split()), expected_str)
811    self.assertEqual(' '.join(repr(rt).split()), expected_repr)
812
813  #=============================================================================
814  # RaggedTensor.with_values() and RaggedTensor.with_flat_values().
815  #=============================================================================
816
817  def testWithValues(self):
818    rt1 = ragged_factory_ops.constant([[1, 2], [3, 4, 5], [6], [], [7]])
819    rt2 = ragged_factory_ops.constant([[[1, 2], [3, 4, 5]], [[6]], [], [[],
820                                                                        [7]]])
821
822    rt1_plus_10 = rt1.with_values(rt1.values + 10)
823    rt2_times_10 = rt2.with_flat_values(rt2.flat_values * 10)
824    rt1_expanded = rt1.with_values(array_ops.expand_dims(rt1.values, axis=1))
825
826    self.assertAllEqual(rt1_plus_10, [[11, 12], [13, 14, 15], [16], [], [17]])
827    self.assertAllEqual(rt2_times_10,
828                        [[[10, 20], [30, 40, 50]], [[60]], [], [[], [70]]])
829    self.assertAllEqual(rt1_expanded,
830                        [[[1], [2]], [[3], [4], [5]], [[6]], [], [[7]]])
831
832  #=============================================================================
833  # Session.run
834  #=============================================================================
835  def testSessionRun(self):
836    if context.executing_eagerly():
837      return
838
839    rt1 = ragged_factory_ops.constant([[1, 2, 3], [4]])
840    rt2 = ragged_factory_ops.constant([[[], [1, 2]], [[3]]])
841    with self.test_session() as session:
842      result = session.run({'rt1': rt1, 'rt2': rt2})
843      self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
844      self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
845      self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
846
847  def testSessionRunFeed(self):
848    if context.executing_eagerly():
849      return
850
851    rt1 = RaggedTensor.from_row_splits(
852        array_ops.placeholder(dtypes.int32),
853        array_ops.placeholder(dtypes.int64))
854    rt2 = RaggedTensor.from_nested_row_splits(
855        array_ops.placeholder(dtypes.int32), [
856            array_ops.placeholder(dtypes.int64),
857            array_ops.placeholder(dtypes.int64)
858        ])
859
860    rt1_feed_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]])
861    rt2_feed_val = ragged_factory_ops.constant_value([[[], [1, 2]], [[3]]])
862
863    with self.test_session() as session:
864      fetches = {'rt1': rt1, 'rt2': rt2}
865      feeds = {rt1: rt1_feed_val, rt2: rt2_feed_val}
866      result = session.run(fetches, feed_dict=feeds)
867      self.assertCountEqual(result.keys(), ['rt1', 'rt2'])
868      self.assertEqual(result['rt1'].to_list(), [[1, 2, 3], [4]])
869      self.assertEqual(result['rt2'].to_list(), [[[], [1, 2]], [[3]]])
870
871  def testSessionPartialRunFeed(self):
872    if context.executing_eagerly():
873      return
874
875    # Placeholder inputs.
876    a = RaggedTensor.from_row_splits(
877        array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'),
878        array_ops.placeholder(dtypes.int64, name='a.row_splits'))
879    b = RaggedTensor.from_row_splits(
880        array_ops.placeholder(dtypes.int32, shape=[None], name='b.values'),
881        array_ops.placeholder(dtypes.int64, name='b.row_splits'))
882    c = array_ops.placeholder(dtypes.int32, shape=[], name='c')
883
884    # Feed values for placeholder inputs.
885    a_val = ragged_factory_ops.constant_value([[1, 2, 3], [4]])
886    b_val = ragged_factory_ops.constant_value([[5, 4, 3], [2]])
887    c_val = 3
888
889    # Compute some values.
890    r1 = ragged_math_ops.reduce_sum(a * b, axis=1)
891    r2 = ragged_math_ops.reduce_sum(a + c, axis=1)
892
893    with self.test_session() as session:
894      handle = session.partial_run_setup([r1, r2], [a, b, c])
895
896      res1 = session.partial_run(handle, r1, feed_dict={a: a_val, b: b_val})
897      self.assertAllEqual(res1, [22, 8])
898
899      res2 = session.partial_run(handle, r2, feed_dict={c: c_val})
900      self.assertAllEqual(res2, [15, 7])
901
902  # Test case for GitHub issue 24679.
903  def testEagerForLoop(self):
904    if not context.executing_eagerly():
905      return
906
907    values = [[1., 2.], [3., 4., 5.], [6.]]
908    r = ragged_factory_ops.constant(values)
909    i = 0
910    for elem in r:
911      self.assertAllEqual(elem, values[i])
912      i += 1
913
914  def testConsumers(self):
915    if context.executing_eagerly():
916      return
917
918    a = RaggedTensor.from_row_splits(
919        array_ops.placeholder(dtypes.int32, shape=[None], name='a.values'),
920        array_ops.placeholder(dtypes.int64, name='a.row_splits'),
921        validate=False)
922    ragged_math_ops.reduce_sum(a)
923    self.assertLen(a.consumers(), 1)
924
925  @parameterized.parameters([
926      {
927          'descr': 'from_value_rowids',
928          'factory': RaggedTensor.from_value_rowids,
929          'test': RaggedTensor.value_rowids,
930          'values': {
931              'values': [1, 2, 3, 4, 5, 6],
932              'value_rowids': [0, 0, 1, 1, 2, 2],
933          },
934          'tensor_field': 'value_rowids',
935          'value_rowids': [0, 1, 2],
936          'nrows': 10
937      },
938      {
939          'descr': 'from_row_splits',
940          'factory': RaggedTensor.from_row_splits,
941          # row_splits is a property, not a function.
942          'test': (lambda rt: rt.row_splits),
943          'values': {
944              'values': [1, 2, 3, 4, 5, 6],
945              'row_splits': [0, 2, 4, 6],
946          },
947          'tensor_field': 'row_splits',
948          'row_splits': [0, 1, 2, 3]
949      },
950      {
951          'descr': 'from_row_lengths',
952          'factory': RaggedTensor.from_row_lengths,
953          'test': RaggedTensor.row_lengths,
954          'values': {
955              'values': [1, 2, 3, 4, 5, 6],
956              'row_lengths': [2, 2, 2],
957          },
958          'tensor_field': 'row_lengths',
959          'row_lengths': [1, 1, 1],
960      },
961      # from_row_starts
962      {
963          'descr': 'from_row_starts',
964          'factory': RaggedTensor.from_row_starts,
965          'test': RaggedTensor.row_starts,
966          'values': {
967              'values': [1, 2, 3, 4, 5, 6],
968              'row_starts': [0, 2, 4]
969          },
970          'tensor_field': 'row_starts',
971          'row_starts': [0, 1, 2]
972      },
973      # from_row_limits
974      {
975          'descr': 'from_row_limits',
976          'factory': RaggedTensor.from_row_limits,
977          'test': RaggedTensor.row_limits,
978          'values': {
979              'values': [1, 2, 3, 4, 5, 6],
980              'row_limits': [2, 4, 6]
981          },
982          'tensor_field': 'row_limits',
983          'row_limits': [3]
984      },
985      # from_uniform_row_length
986      {
987          'descr': 'from_uniform_row_length',
988          'factory': RaggedTensor.from_uniform_row_length,
989          # One cannot extract uniform_row_length or nvals, so we return
990          # nvals//nrows = uniform_row_length, where nvals = 3
991          'test': (lambda rt: 3 // (rt.shape[0])),
992          'values': {
993              'values': [1, 2, 3, 4, 5, 6],
994              'uniform_row_length': 2
995          },
996          'tensor_field': 'uniform_row_length',
997          'uniform_row_length': 3
998      },
999  ])
1000  def testFactoryTypePreference(self, descr, test, factory, values,
1001                                tensor_field, **kwargs):
1002    # When input tensors have shape information, some of these errors will be
1003    # detected statically.
1004    def op_cast(k, v):
1005      if k == tensor_field:
1006        return constant_op.constant(v, dtype=dtypes.int32)
1007      else:
1008        return v
1009
1010    value_copy = {k: op_cast(k, v) for k, v in values.items()}
1011    rt = factory(**value_copy)
1012
1013    kw_copy = {k: v for k, v in kwargs.items()}
1014    kw_copy['values'] = rt
1015    rt2 = factory(**kw_copy)
1016    self.assertAllEqual(kwargs[tensor_field], test(rt2))
1017
1018  @parameterized.parameters([
1019      # from_value_rowids
1020      {
1021          'descr': 'bad rank for value_rowids',
1022          'factory': RaggedTensor.from_value_rowids,
1023          'values': [[1, 2], [3, 4]],
1024          'value_rowids': [[1, 2], [3, 4]],
1025          'nrows': 10
1026      },
1027      {
1028          'descr': 'bad rank for nrows',
1029          'factory': RaggedTensor.from_value_rowids,
1030          'values': [1, 2, 3, 4],
1031          'value_rowids': [1, 2, 3, 4],
1032          'nrows': [10]
1033      },
1034      {
1035          'descr': 'len(values) != len(value_rowids)',
1036          'factory': RaggedTensor.from_value_rowids,
1037          'values': [1, 2, 3, 4],
1038          'value_rowids': [1, 2, 3, 4, 5],
1039          'nrows': 10
1040      },
1041      {
1042          'descr': 'negative value_rowid',
1043          'factory': RaggedTensor.from_value_rowids,
1044          'values': [1, 2, 3, 4],
1045          'value_rowids': [-5, 2, 3, 4],
1046          'nrows': 10
1047      },
1048      {
1049          'descr': 'non-monotonic-increasing value_rowid',
1050          'factory': RaggedTensor.from_value_rowids,
1051          'values': [1, 2, 3, 4],
1052          'value_rowids': [4, 3, 2, 1],
1053          'nrows': 10
1054      },
1055      {
1056          'descr': 'value_rowid > nrows',
1057          'factory': RaggedTensor.from_value_rowids,
1058          'values': [1, 2, 3, 4],
1059          'value_rowids': [1, 2, 3, 4],
1060          'nrows': 2
1061      },
1062      {
1063          'descr': 'bad rank for values',
1064          'factory': RaggedTensor.from_value_rowids,
1065          'values': 10,
1066          'value_rowids': [1, 2, 3, 4],
1067          'nrows': 10
1068      },
1069
1070      # from_row_splits
1071      {
1072          'descr': 'bad rank for row_splits',
1073          'factory': RaggedTensor.from_row_splits,
1074          'values': [[1, 2], [3, 4]],
1075          'row_splits': [[1, 2], [3, 4]]
1076      },
1077      {
1078          'descr': 'row_splits[0] != 0',
1079          'factory': RaggedTensor.from_row_splits,
1080          'values': [1, 2, 3, 4],
1081          'row_splits': [2, 3, 4]
1082      },
1083      {
1084          'descr': 'non-monotonic-increasing row_splits',
1085          'factory': RaggedTensor.from_row_splits,
1086          'values': [1, 2, 3, 4],
1087          'row_splits': [0, 3, 2, 4]
1088      },
1089      {
1090          'descr': 'row_splits[0] != nvals',
1091          'factory': RaggedTensor.from_row_splits,
1092          'values': [1, 2, 3, 4],
1093          'row_splits': [0, 2, 3, 5]
1094      },
1095      {
1096          'descr': 'bad rank for values',
1097          'factory': RaggedTensor.from_row_splits,
1098          'values': 10,
1099          'row_splits': [0, 1]
1100      },
1101
1102      # from_row_lengths
1103      {
1104          'descr': 'bad rank for row_lengths',
1105          'factory': RaggedTensor.from_row_lengths,
1106          'values': [1, 2, 3, 4],
1107          'row_lengths': [[1, 2], [1, 0]]
1108      },
1109      {
1110          'descr': 'negatve row_lengths',
1111          'factory': RaggedTensor.from_row_lengths,
1112          'values': [1, 2, 3, 4],
1113          'row_lengths': [3, -1, 2]
1114      },
1115      {
1116          'descr': 'sum(row_lengths) != nvals',
1117          'factory': RaggedTensor.from_row_lengths,
1118          'values': [1, 2, 3, 4],
1119          'row_lengths': [2, 4, 2, 8]
1120      },
1121      {
1122          'descr': 'bad rank for values',
1123          'factory': RaggedTensor.from_row_lengths,
1124          'values': 10,
1125          'row_lengths': [0, 1]
1126      },
1127
1128      # from_row_starts
1129      {
1130          'descr': 'bad rank for row_starts',
1131          'factory': RaggedTensor.from_row_starts,
1132          'values': [[1, 2], [3, 4]],
1133          'row_starts': [[1, 2], [3, 4]]
1134      },
1135      {
1136          'descr': 'row_starts[0] != 0',
1137          'factory': RaggedTensor.from_row_starts,
1138          'values': [1, 2, 3, 4],
1139          'row_starts': [2, 3, 4]
1140      },
1141      {
1142          'descr': 'non-monotonic-increasing row_starts',
1143          'factory': RaggedTensor.from_row_starts,
1144          'values': [1, 2, 3, 4],
1145          'row_starts': [0, 3, 2, 4]
1146      },
1147      {
1148          'descr': 'row_starts[0] > nvals',
1149          'factory': RaggedTensor.from_row_starts,
1150          'values': [1, 2, 3, 4],
1151          'row_starts': [0, 2, 3, 5]
1152      },
1153      {
1154          'descr': 'bad rank for values',
1155          'factory': RaggedTensor.from_row_starts,
1156          'values': 10,
1157          'row_starts': [0, 1]
1158      },
1159
1160      # from_row_limits
1161      {
1162          'descr': 'bad rank for row_limits',
1163          'factory': RaggedTensor.from_row_limits,
1164          'values': [[1, 2], [3, 4]],
1165          'row_limits': [[1, 2], [3, 4]]
1166      },
1167      {
1168          'descr': 'row_limits[0] < 0',
1169          'factory': RaggedTensor.from_row_limits,
1170          'values': [1, 2, 3, 4],
1171          'row_limits': [-1, 3, 4]
1172      },
1173      {
1174          'descr': 'non-monotonic-increasing row_limits',
1175          'factory': RaggedTensor.from_row_limits,
1176          'values': [1, 2, 3, 4],
1177          'row_limits': [0, 3, 2, 4]
1178      },
1179      {
1180          'descr': 'row_limits[0] != nvals',
1181          'factory': RaggedTensor.from_row_limits,
1182          'values': [1, 2, 3, 4],
1183          'row_limits': [0, 2, 3, 5]
1184      },
1185      {
1186          'descr': 'bad rank for values',
1187          'factory': RaggedTensor.from_row_limits,
1188          'values': 10,
1189          'row_limits': [0, 1]
1190      },
1191
1192      # from_uniform_row_length
1193      {
1194          'descr': 'rowlen * nrows != nvals (1)',
1195          'factory': RaggedTensor.from_uniform_row_length,
1196          'values': [1, 2, 3, 4, 5],
1197          'uniform_row_length': 3
1198      },
1199      {
1200          'descr': 'rowlen * nrows != nvals (2)',
1201          'factory': RaggedTensor.from_uniform_row_length,
1202          'values': [1, 2, 3, 4, 5],
1203          'uniform_row_length': 6
1204      },
1205      {
1206          'descr': 'rowlen * nrows != nvals (3)',
1207          'factory': RaggedTensor.from_uniform_row_length,
1208          'values': [1, 2, 3, 4, 5, 6],
1209          'uniform_row_length': 3,
1210          'nrows': 3
1211      },
1212      {
1213          'descr': 'rowlen must be a scalar',
1214          'factory': RaggedTensor.from_uniform_row_length,
1215          'values': [1, 2, 3, 4],
1216          'uniform_row_length': [2]
1217      },
1218      {
1219          'descr': 'rowlen must be nonnegative',
1220          'factory': RaggedTensor.from_uniform_row_length,
1221          'values': [1, 2, 3, 4],
1222          'uniform_row_length': -1
1223      },
1224  ])
1225  def testFactoryValidation(self, descr, factory, **kwargs):
1226    # When input tensors have shape information, some of these errors will be
1227    # detected statically.
1228    with self.assertRaises((errors.InvalidArgumentError, ValueError)):
1229      self.evaluate(factory(**kwargs))
1230
1231    # Remove shape information (by wrapping tensors in placeholders), and check
1232    # that we detect the errors when the graph is run.
1233    if not context.executing_eagerly():
1234
1235      def wrap_arg(v):
1236        return array_ops.placeholder_with_default(
1237            constant_op.constant(v, dtype=dtypes.int64),
1238            tensor_shape.TensorShape(None))
1239
1240      kwargs = dict((k, wrap_arg(v)) for (k, v) in kwargs.items())
1241
1242      with self.assertRaises(errors.InvalidArgumentError):
1243        self.evaluate(factory(**kwargs))
1244
1245  #=============================================================================
1246  # RaggedTensor Variant conversion
1247  #=============================================================================
1248
1249  @parameterized.named_parameters(
1250      {
1251          'testcase_name': 'Shape_5_none',
1252          'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
1253          'ragged_rank': 1
1254      }, {
1255          'testcase_name': 'Shape_4_none_2',
1256          'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
1257          'ragged_rank': 1
1258      }, {
1259          'testcase_name': 'Shape_1_none_none',
1260          'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
1261          'ragged_rank': 2
1262      })
1263  def testRaggedToVariant(self, ragged_constant, ragged_rank):
1264    rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank)
1265    et = rt._to_variant()
1266    self.assertEqual(et.shape.as_list(), [])
1267    self.assertEqual(et.dtype, dtypes.variant)
1268
1269  @parameterized.parameters(
1270      {
1271          'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
1272          'ragged_rank': 1,
1273          'num_batched_elems': 5
1274      }, {
1275          'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
1276          'ragged_rank': 1,
1277          'num_batched_elems': 4
1278      }, {
1279          'ragged_constant': [[[1], [2, 3, 4, 5, 6, 7]], [[]]],
1280          'ragged_rank': 2,
1281          'num_batched_elems': 2
1282      })
1283  def testRaggedToBatchedVariant(self, ragged_constant, ragged_rank,
1284                                 num_batched_elems):
1285    rt = ragged_factory_ops.constant(ragged_constant, ragged_rank=ragged_rank)
1286    et = rt._to_variant(batched_input=True)
1287    self.assertEqual(et.shape.as_list(), [num_batched_elems])
1288    self.assertEqual(et.dtype, dtypes.variant)
1289
1290  @parameterized.parameters(
1291      # 2D test cases.
1292      {
1293          'ragged_constant': [[]],
1294          'ragged_rank': 1,
1295      },
1296      {
1297          'ragged_constant': [[1]],
1298          'ragged_rank': 1,
1299      },
1300      {
1301          'ragged_constant': [[1, 2]],
1302          'ragged_rank': 1,
1303      },
1304      {
1305          'ragged_constant': [[1], [2], [3]],
1306          'ragged_rank': 1,
1307      },
1308      {
1309          'ragged_constant': [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
1310          'ragged_rank': 1,
1311      },
1312      {
1313          'ragged_constant': [[1, 2], [3, 4, 5], [6], [], [7]],
1314          'ragged_rank': 1,
1315      },
1316      # 3D test cases.
1317      {
1318          'ragged_constant': [[[]]],
1319          'ragged_rank': 2,
1320      },
1321      {
1322          'ragged_constant': [[[1]]],
1323          'ragged_rank': 2,
1324      },
1325      {
1326          'ragged_constant': [[[1, 2]]],
1327          'ragged_rank': 2,
1328      },
1329      {
1330          'ragged_constant': [[[1, 2], [3, 4]]],
1331          'ragged_rank': 2,
1332      },
1333      {
1334          'ragged_constant': [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]],
1335          'ragged_rank': 2,
1336      },
1337      {
1338          'ragged_constant': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
1339          'ragged_rank': 2,
1340      },
1341      {
1342          'ragged_constant': [[[1, 2]], [], [[3, 4]], []],
1343          'ragged_rank': 2,
1344      },
1345      # 4D test cases.
1346      {
1347          'ragged_constant': [[[[1, 2], [3, 4]]],
1348                              [[[0, 0], [0, 0]], [[5, 6], [7, 8]]], []],
1349          'ragged_rank': 3,
1350      },
1351      # dtype `string`.
1352      {
1353          'ragged_constant': [['a'], ['b'], ['c']],
1354          'ragged_rank': 1,
1355          'dtype': dtypes.string,
1356      },
1357      {
1358          'ragged_constant': [[['a', 'b'], ['c', 'd']]],
1359          'ragged_rank': 2,
1360          'dtype': dtypes.string,
1361      },
1362      {
1363          'ragged_constant': [[[['a', 'b'], ['c', 'd']]],
1364                              [[['e', 'f'], ['g', 'h']], [['i', 'j'],
1365                                                          ['k', 'l']]], []],
1366          'ragged_rank': 3,
1367          'dtype': dtypes.string,
1368      })
1369  def testVariantRoundTrip(self,
1370                           ragged_constant,
1371                           ragged_rank,
1372                           dtype=dtypes.int32):
1373    rt = ragged_factory_ops.constant(
1374        ragged_constant, ragged_rank=ragged_rank, dtype=dtype)
1375    et = rt._to_variant()
1376    round_trip_rt = RaggedTensor._from_variant(
1377        et, dtype, output_ragged_rank=ragged_rank)
1378    self.assertAllEqual(rt, round_trip_rt)
1379
1380  def testBatchedVariantRoundTripInputRaggedRankInferred(self):
1381    ragged_rank = 1
1382    rt = ragged_factory_ops.constant(
1383        [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
1384        ragged_rank=ragged_rank)
1385    batched_variant = rt._to_variant(batched_input=True)
1386    nested_batched_variant = array_ops.reshape(batched_variant, [5, 2])
1387    decoded_rt = RaggedTensor._from_variant(
1388        nested_batched_variant,
1389        dtype=dtypes.int32,
1390        output_ragged_rank=ragged_rank + 1)
1391    expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
1392                                                                        [5]],
1393                                               [[6], [7]], [[8], [9]]])
1394    self.assertAllEqual(decoded_rt, expected_rt)
1395
1396  def testBatchedVariantRoundTripWithInputRaggedRank(self):
1397    ragged_rank = 1
1398    rt = ragged_factory_ops.constant(
1399        [[0], [1], [2], [3], [4], [5], [6], [7], [8], [9]],
1400        ragged_rank=ragged_rank)
1401    batched_variant = rt._to_variant(batched_input=True)
1402    nested_batched_variant = array_ops.reshape(batched_variant, [5, 2])
1403    decoded_rt = RaggedTensor._from_variant(
1404        nested_batched_variant,
1405        dtype=dtypes.int32,
1406        output_ragged_rank=ragged_rank + 1,
1407        input_ragged_rank=ragged_rank - 1)
1408    expected_rt = ragged_factory_ops.constant([[[0], [1]], [[2], [3]], [[4],
1409                                                                        [5]],
1410                                               [[6], [7]], [[8], [9]]])
1411    self.assertAllEqual(decoded_rt, expected_rt)
1412
1413  def testUnbatchVariant(self):  # b/141789000
1414    rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
1415    batched = rt._to_variant(batched_input=True)
1416    for i in range(4):
1417      row = RaggedTensor._from_variant(
1418          batched[i], dtype=dtypes.int32, output_ragged_rank=0)
1419      self.assertAllEqual(rt[i], row)
1420
1421  def testUnbatchVariantInDataset(self):
1422    rt = ragged_factory_ops.constant([[1, 2, 3], [4, 5], [], [6, 7, 8, 9]])
1423    ds = dataset_ops.Dataset.from_tensor_slices(rt)
1424    if context.executing_eagerly():
1425      for i, value in enumerate(ds):
1426        self.assertAllEqual(rt[i], value)
1427    else:
1428      it = dataset_ops.make_one_shot_iterator(ds)
1429      out = it.get_next()
1430      with self.cached_session() as sess:
1431        for i in range(3):
1432          self.assertAllEqual(sess.run(rt[i]), out)
1433
1434  def testFromVariantInvalidParams(self):
1435    rt = ragged_factory_ops.constant([[0], [1], [2], [3]])
1436    batched_variant = rt._to_variant(batched_input=True)
1437    nested_batched_variant = array_ops.reshape(batched_variant, [2, 2])
1438    with self.assertRaisesRegex(ValueError,
1439                                'output_ragged_rank must be equal to'):
1440      RaggedTensor._from_variant(
1441          nested_batched_variant,
1442          dtype=dtypes.int32,
1443          output_ragged_rank=1,
1444          input_ragged_rank=1)
1445
1446  def _testRaggedVarientGradient(self, func, x, expected_grad):
1447    x = constant_op.constant(x)
1448    if context.executing_eagerly():
1449      with backprop.GradientTape() as t:
1450        t.watch(x)
1451        y = func(x)
1452        g = t.gradient(y, x)
1453    else:
1454      y = func(x)
1455      g = gradients_impl.gradients(ys=y, xs=x)[0]
1456    self.assertAllClose(g, expected_grad)
1457
1458  def testRaggedVariantGradients(self):
1459    def func(x):
1460      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
1461      rt2 = rt1 * [[10], [100], [1000]]
1462      v = rt2._to_variant(batched_input=False)
1463      rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
1464      return rt3.flat_values
1465
1466    self._testRaggedVarientGradient(
1467        func,
1468        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1469        [10., 10., 10., 10., 100., 100., 100., 1000.])
1470
1471  def testRaggedVariantGradientsBatched(self):
1472    def func(x):
1473      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
1474      rt2 = rt1 * [[10], [100], [1000]]
1475      v = rt2._to_variant(batched_input=True)
1476      rt3 = RaggedTensor._from_variant(v, dtype=rt2.dtype, output_ragged_rank=1)
1477      return rt3.flat_values
1478
1479    self._testRaggedVarientGradient(
1480        func,
1481        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1482        [10., 10., 10., 10., 100., 100., 100., 1000.])
1483
1484  def testRaggedVariantGradientsBatchedAndSliced(self):
1485    def func(x, i):
1486      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
1487      rt2 = rt1 * [[10], [100], [1000]]
1488      v_slice = rt2._to_variant(batched_input=True)[i]
1489      return RaggedTensor._from_variant(v_slice, dtype=rt2.dtype,
1490                                        output_ragged_rank=0)
1491
1492    self._testRaggedVarientGradient(
1493        functools.partial(func, i=0),
1494        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1495        [10., 10., 10., 10., 0., 0., 0., 0.])
1496    self._testRaggedVarientGradient(
1497        functools.partial(func, i=1),
1498        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1499        [0., 0., 0., 0., 100., 100., 100., 0.])
1500    self._testRaggedVarientGradient(
1501        functools.partial(func, i=2),
1502        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1503        [0., 0., 0., 0., 0., 0., 0., 1000.])
1504
1505  def testRaggedVariantGradientsRaggedRank0(self):
1506    def func(x):
1507      x2 = x * 2
1508      v = gen_ragged_conversion_ops.ragged_tensor_to_variant(
1509          [], x2, batched_input=False)
1510      return RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=0)
1511
1512    self._testRaggedVarientGradient(
1513        func,
1514        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1515        [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
1516
1517  def testRaggedVariantGradientsRaggedRank3(self):
1518    def func(x):
1519      x2 = x * 2
1520      rt1 = RaggedTensor.from_nested_row_splits(
1521          x2, ([0, 0, 3], [0, 2, 2, 3], [0, 4, 7, 8]))
1522      v = rt1._to_variant(batched_input=False)
1523      rt3 = RaggedTensor._from_variant(v, dtype=x2.dtype, output_ragged_rank=3)
1524      return rt3.flat_values
1525
1526    self._testRaggedVarientGradient(
1527        func,
1528        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1529        [2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0])
1530
1531  def testRaggedVariantGradientsViaMapFn(self):
1532    rt = RaggedTensor.from_row_splits(
1533        values=[3, 1.0, 4, 1, 5, 9, 2, 6], row_splits=[0, 4, 7, 8])
1534
1535    def func(x):
1536
1537      def transform_row(row):
1538        return math_ops.sqrt(
1539            math_ops.reduce_mean(math_ops.square(row * x), keepdims=True))
1540
1541      return math_ops.reduce_sum(map_fn.map_fn(transform_row, rt))
1542
1543    self._testRaggedVarientGradient(func, 3.0, 14.653377)
1544
1545  def testRaggedVariantGradientsViaMapFnReduce(self):
1546    def func(x):
1547      rt1 = RaggedTensor.from_row_splits(values=x, row_splits=[0, 4, 7, 8])
1548      return map_fn.map_fn(
1549          math_ops.reduce_max, rt1,
1550          fn_output_signature=tensor_spec.TensorSpec((), x.dtype))
1551
1552    self._testRaggedVarientGradient(
1553        func,
1554        [3.0, 1.0, 4.0, 1.0, 1.0, 0.0, 2.0, 1.0],
1555        [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0])
1556
1557  def testRaggedVariantGradientsErrors(self):
1558    if context.executing_eagerly():
1559      return
1560
1561    rt = RaggedTensor.from_row_splits([1.0, 2.0], row_splits=[0, 2, 2])
1562    v1 = rt._to_variant()
1563    v2 = array_ops.stack([array_ops.stack([v1])])
1564    y = RaggedTensor._from_variant(v2, rt.dtype, output_ragged_rank=3)
1565
1566    with self.assertRaisesRegex(
1567        ValueError, 'Unable to compute gradient: RaggedTensorToVariant '
1568        'can currently only generate 0D or 1D output.'):
1569      gradients_impl.gradients(ys=y.flat_values, xs=rt.flat_values)
1570
1571  def assertNumpyObjectTensorsRecursivelyEqual(self, a, b, msg):
1572    """Check that two numpy arrays are equal.
1573
1574    For arrays with dtype=object, check values recursively to see if a and b
1575    are equal.  (c.f. `np.array_equal`, which checks dtype=object values using
1576    object identity.)
1577
1578    Args:
1579      a: A numpy array.
1580      b: A numpy array.
1581      msg: Message to display if a != b.
1582    """
1583    if isinstance(a, np.ndarray) and a.dtype == object:
1584      self.assertEqual(a.dtype, b.dtype, msg)
1585      self.assertEqual(a.shape, b.shape, msg)
1586      self.assertLen(a, len(b), msg)
1587      for a_val, b_val in zip(a, b):
1588        self.assertNumpyObjectTensorsRecursivelyEqual(a_val, b_val, msg)
1589    else:
1590      self.assertAllEqual(a, b, msg)
1591
1592  @parameterized.named_parameters([
1593      ('Shape_2_R',
1594       [[1, 2], [3, 4, 5]],
1595       np.array([int32array([1, 2]), int32array([3, 4, 5])])),
1596      ('Shape_2_2',
1597       [[1, 2], [3, 4]],
1598       np.array([[1, 2], [3, 4]])),
1599      ('Shape_2_R_2',
1600       [[[1, 2], [3, 4]], [[5, 6]]],
1601       np.array([int32array([[1, 2], [3, 4]]), int32array([[5, 6]])])),
1602      ('Shape_3_2_R',
1603       [[[1], []], [[2, 3], [4]], [[], [5, 6, 7]]],
1604       np.array([[int32array([1]), int32array([])],
1605                 [int32array([2, 3]), int32array([4])],
1606                 [int32array([]), int32array([5, 6, 7])]])),
1607      ('Shape_0_R',
1608       ragged_factory_ops.constant_value([], ragged_rank=1, dtype=np.int32),
1609       np.zeros([0, 0], dtype=np.int32)),
1610      ('Shape_0_R_2',
1611       ragged_factory_ops.constant_value([], ragged_rank=1,
1612                                         inner_shape=(2,), dtype=np.int32),
1613       np.zeros([0, 0, 2], dtype=np.int32)),
1614  ])  # pyformat: disable
1615  def testRaggedTensorNumpy(self, rt, expected):
1616    if isinstance(rt, list):
1617      rt = ragged_factory_ops.constant(rt, dtype=dtypes.int32)
1618    else:
1619      rt = ragged_tensor.convert_to_tensor_or_ragged_tensor(rt)
1620    if context.executing_eagerly():
1621      actual = rt.numpy()
1622      self.assertNumpyObjectTensorsRecursivelyEqual(
1623          expected, actual, 'Expected %r, got %r' % (expected, actual))
1624    else:
1625      with self.assertRaisesRegex(ValueError, 'only supported in eager mode'):
1626        rt.numpy()
1627
1628  @parameterized.parameters([
1629      ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
1630      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
1631      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
1632      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
1633      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
1634      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
1635      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
1636      ([[[1, 2, 3]]], 1, [1, 1, None]),
1637      ([[[1, 2, 3]]], 1, [1, 1, 3]),
1638  ])
1639  def testRaggedTensorSetShape(self, rt, rt_ragged_rank, shape):
1640    rt1 = ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank)
1641    rt1._set_shape(shape)
1642    rt1.shape.assert_is_compatible_with(shape)
1643    if shape is not None:
1644      self.assertIsNot(rt1.shape.rank, None)
1645      for a, b in zip(rt1.shape, shape):
1646        if b is not None:
1647          self.assertEqual(a, b)
1648
1649  @parameterized.parameters([
1650      ([[[1, 2], [3, 4, 5]], [[6]]], 2, None),
1651      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [None, None, None]),
1652      ([[[1, 2], [3, 4, 5]], [[6]]], 2, [2, None, None]),
1653      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, None),
1654      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [None, None, None]),
1655      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, None]),
1656      ([[[1, 2, 3], [4, 5, 6]], [[7, 8, 9]]], 1, [2, None, 3]),
1657      ([[[1, 2, 3]]], 1, [1, 1, None]),
1658      ([[[1, 2, 3]]], 1, [1, 1, 3]),
1659  ])
1660  def testRaggedTensorSetShapeWithPlaceholders(self, rt, rt_ragged_rank, shape):
1661    rt2 = nest.map_structure(
1662        lambda x: array_ops.placeholder_with_default(x, None),
1663        ragged_factory_ops.constant(rt, ragged_rank=rt_ragged_rank),
1664        expand_composites=True)
1665    rt2._set_shape(shape)
1666    rt2.shape.assert_is_compatible_with(shape)
1667    if shape is not None:
1668      self.assertIsNot(rt2.shape.rank, None)
1669      for a, b in zip(rt2.shape, shape):
1670        if b is not None:
1671          self.assertEqual(a, b)
1672
1673  def testRaggedTensorSetShapeUniformRowLength(self):
1674    rt = [[[1], [2], [3]], [[4], [5], [6]]]
1675
1676    rt1 = RaggedTensor.from_tensor(rt, ragged_rank=1)
1677    rt1._set_shape([2, 3, 1])
1678
1679    rt2 = nest.map_structure(
1680        lambda x: array_ops.placeholder_with_default(x, None),
1681        rt1, expand_composites=True)
1682    rt2._set_shape([2, 3, 1])
1683
1684  def testRaggedTensorSetShapeInconsistentShapeError(self):
1685    rt = RaggedTensor.from_tensor([[[1], [2], [3]], [[4], [5], [6]]],
1686                                  ragged_rank=1)
1687    self.assertEqual(rt.shape.as_list(), [2, 3, 1])
1688    with self.assertRaises(ValueError):
1689      rt._set_shape([None, None, 5])
1690    with self.assertRaisesRegex(ValueError, 'Inconsistent size'):
1691      rt._set_shape([None, 5, None])
1692    with self.assertRaises(ValueError):
1693      rt._set_shape([5, None, None])
1694
1695
@test_util.run_all_in_graph_and_eager_modes
class RaggedTensorSpecTest(test_util.TensorFlowTestCase,
                           parameterized.TestCase):
  """Tests for `RaggedTensorSpec`."""

  def assertAllTensorsEqual(self, list1, list2):
    """Asserts equal lengths and pairwise-equal tensors."""
    self.assertLen(list1, len(list2))
    for actual, expected in zip(list1, list2):
      self.assertAllEqual(actual, expected)

  def testConstruction(self):
    # With only ragged_rank given, the shape rank is unknown and both dtypes
    # take their defaults; private fields and public properties agree.
    spec1 = RaggedTensorSpec(ragged_rank=1)
    for prefix in ('_', ''):
      self.assertIsNone(getattr(spec1, prefix + 'shape').rank)
      self.assertEqual(getattr(spec1, prefix + 'dtype'), dtypes.float32)
      self.assertEqual(
          getattr(spec1, prefix + 'row_splits_dtype'), dtypes.int64)
      self.assertEqual(getattr(spec1, prefix + 'ragged_rank'), 1)

    # With only a rank-3 shape given, the inferred ragged_rank is 2.
    spec2 = RaggedTensorSpec(shape=[None, None, None])
    self.assertEqual(spec2._shape.as_list(), [None, None, None])
    self.assertEqual(spec2._dtype, dtypes.float32)
    self.assertEqual(spec2._row_splits_dtype, dtypes.int64)
    self.assertEqual(spec2._ragged_rank, 2)

    # Invalid argument combinations.
    with self.assertRaisesRegex(ValueError, 'Must specify ragged_rank'):
      RaggedTensorSpec()
    with self.assertRaisesRegex(TypeError, 'ragged_rank must be an int'):
      RaggedTensorSpec(ragged_rank=constant_op.constant(1))
    with self.assertRaisesRegex(ValueError,
                                'ragged_rank must be less than rank'):
      RaggedTensorSpec(ragged_rank=2, shape=[None, None])

  def testValueType(self):
    # A ragged_rank=0 spec describes a plain Tensor.
    self.assertEqual(RaggedTensorSpec(ragged_rank=1).value_type, RaggedTensor)
    self.assertEqual(RaggedTensorSpec(ragged_rank=0).value_type, ops.Tensor)

  @parameterized.parameters([
      (RaggedTensorSpec(ragged_rank=1),
       (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int64)),
      (RaggedTensorSpec(shape=[5, None, None]),
       (tensor_shape.TensorShape([5, None, None]), dtypes.float32,
        2, dtypes.int64)),
      (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.int32),
       (tensor_shape.TensorShape([5, None, None]), dtypes.int32, 2,
        dtypes.int64)),
      (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32),
       (tensor_shape.TensorShape(None), dtypes.float32, 1, dtypes.int32)),
  ])  # pyformat: disable
  def testSerialize(self, rt_spec, expected):
    # TensorShape defines equality unconventionally, so compare repr()s,
    # which are deterministic and lossless for these expected values.
    self.assertEqual(repr(rt_spec._serialize()), repr(expected))

  @parameterized.parameters([
      (RaggedTensorSpec(ragged_rank=0, shape=[5, 3]), [
          tensor_spec.TensorSpec([5, 3], dtypes.float32),
      ]),
      (RaggedTensorSpec(ragged_rank=1), [
          tensor_spec.TensorSpec(None, dtypes.float32),
          tensor_spec.TensorSpec([None], dtypes.int64)
      ]),
      (RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32), [
          tensor_spec.TensorSpec(None, dtypes.float32),
          tensor_spec.TensorSpec([None], dtypes.int32),
      ]),
      (RaggedTensorSpec(ragged_rank=2), [
          tensor_spec.TensorSpec(None, dtypes.float32),
          tensor_spec.TensorSpec([None], dtypes.int64),
          tensor_spec.TensorSpec([None], dtypes.int64),
      ]),
      (RaggedTensorSpec(shape=[5, None, None], dtype=dtypes.string), [
          tensor_spec.TensorSpec([None], dtypes.string),
          tensor_spec.TensorSpec([6], dtypes.int64),
          tensor_spec.TensorSpec([None], dtypes.int64),
      ]),
  ])
  def testComponentSpecs(self, rt_spec, expected):
    actual_specs = rt_spec._component_specs
    self.assertEqual(actual_specs, expected)

  @parameterized.parameters([
      {
          'rt_spec': RaggedTensorSpec(ragged_rank=0),
          'rt': [1.0, 2.0, 3.0],
          'components': [[1.0, 2.0, 3.0]]
      },
      {
          'rt_spec': RaggedTensorSpec(ragged_rank=1),
          'rt': [[1.0, 2.0], [3.0]],
          'components': [[1.0, 2.0, 3.0], [0, 2, 3]]
      },
      {
          'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
          'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]],
          'components': [[1.0, 2.0, 3.0, 4.0], [0, 2, 4], [0, 2, 3, 3, 4]]
      },
  ])
  def testToFromComponents(self, rt_spec, rt, components):
    ragged = ragged_factory_ops.constant(rt)
    encoded = rt_spec._to_components(ragged)
    self.assertAllTensorsEqual(encoded, components)
    # Decoding the components must reproduce the original value.
    self.assertAllEqual(ragged, rt_spec._from_components(encoded))

  @test_util.run_v1_only('RaggedTensorValue is deprecated in v2')
  def testFromNumpyComponents(self):
    # (spec, numpy components, expected result type, expected value)
    cases = [
        (RaggedTensorSpec(ragged_rank=1, dtype=dtypes.int32),
         [np.array([1, 2, 3]), np.array([0, 2, 3])],
         ragged_tensor_value.RaggedTensorValue,
         [[1, 2], [3]]),
        (RaggedTensorSpec(ragged_rank=2, dtype=dtypes.int32),
         [np.array([1, 2, 3]),
          np.array([0, 2, 3]),
          np.array([0, 0, 2, 3])],
         ragged_tensor_value.RaggedTensorValue,
         [[[], [1, 2]], [[3]]]),
        (RaggedTensorSpec(ragged_rank=0, dtype=dtypes.int32),
         [np.array([1, 2, 3])],
         np.ndarray,
         [1, 2, 3]),
    ]
    for spec, components, expected_type, expected_value in cases:
      value = spec._from_components(components)
      self.assertIsInstance(value, expected_type)
      self.assertAllEqual(value, expected_value)

  @parameterized.parameters([
      RaggedTensorSpec(ragged_rank=0, shape=[5, 3]),
      RaggedTensorSpec(ragged_rank=1),
      RaggedTensorSpec(ragged_rank=1, row_splits_dtype=dtypes.int32),
      RaggedTensorSpec(ragged_rank=2, dtype=dtypes.string),
      RaggedTensorSpec(shape=[5, None, None]),
  ])
  def testFlatTensorSpecs(self, rt_spec):
    # Every spec flattens to a single variant tensor spec.
    expected = [tensor_spec.TensorSpec(None, dtypes.variant)]
    self.assertEqual(rt_spec._flat_tensor_specs, expected)

  @parameterized.named_parameters([
      {
          'testcase_name': 'RaggedRank0',
          'rt_spec': RaggedTensorSpec(ragged_rank=0),
          'rt': [1.0, 2.0, 3.0],
      },
      {
          'testcase_name': 'RaggedRank1',
          'rt_spec': RaggedTensorSpec(ragged_rank=1),
          'rt': [[1.0, 2.0], [3.0]]
      },
      {
          'testcase_name': 'RaggedRank2',
          'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
          'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
      },
  ])
  def testToFromTensorList(self, rt_spec, rt):
    ragged = ragged_factory_ops.constant(rt)
    round_tripped = rt_spec._from_tensor_list(rt_spec._to_tensor_list(ragged))
    self.assertAllEqual(ragged, round_tripped)

  @parameterized.named_parameters([
      # TODO(b/141789000) Test ragged_rank=0 when support is added.
      {
          'testcase_name': 'RaggedRank1',
          'rt_spec': RaggedTensorSpec(ragged_rank=1),
          'rt': [[1.0, 2.0], [3.0]]
      },
      {
          'testcase_name': 'RaggedRank2',
          'rt_spec': RaggedTensorSpec(shape=[2, None, None]),
          'rt': [[[1.0, 2.0], [3.0]], [[], [4.0]]]
      },
  ])
  def testToFromBatchedTensorList(self, rt_spec, rt):
    ragged = ragged_factory_ops.constant(rt)
    tensor_list = rt_spec._to_batched_tensor_list(ragged)
    self.assertAllEqual(ragged, rt_spec._from_tensor_list(tensor_list))
    # Taking element 0 of each batched tensor and decoding with the unbatched
    # spec must give the first row of the original value.
    first_row_tensors = [batched[0] for batched in tensor_list]
    first_row = rt_spec._unbatch()._from_tensor_list(first_row_tensors)
    self.assertAllEqual(ragged[0], first_row)

  def testToFromBatchedTensorListPreservesUniformRowLengths(self):
    rt = RaggedTensor.from_tensor(array_ops.zeros([3, 4, 5]), ragged_rank=2)
    spec = rt._type_spec
    reconstructed = spec._from_tensor_list(spec._to_batched_tensor_list(rt))
    self.assertAllEqual(rt, reconstructed)
    # The fully-defined static shape must survive the round trip too.
    for value in (rt, reconstructed):
      self.assertTrue(value.shape.is_fully_defined())
    self.assertEqual(rt.shape.as_list(), reconstructed.shape.as_list())

  @parameterized.parameters([
      (RaggedTensorSpec([2, None], dtypes.float32, 1), 32,
       RaggedTensorSpec([32, 2, None], dtypes.float32, 2)),
      (RaggedTensorSpec([4, None], dtypes.float32, 1), None,
       RaggedTensorSpec([None, 4, None], dtypes.float32, 2)),
      (RaggedTensorSpec([2], dtypes.float32, -1), 32,
       RaggedTensorSpec([32, 2], dtypes.float32, 0)),
  ])
  def testBatch(self, spec, batch_size, expected):
    batched = spec._batch(batch_size)
    self.assertEqual(batched, expected)

  @parameterized.parameters([
      (RaggedTensorSpec([32, None, None], dtypes.float32, 2),
       RaggedTensorSpec([None, None], dtypes.float32, 1)),
      (RaggedTensorSpec([None, None, None], dtypes.float32, 2),
       RaggedTensorSpec([None, None], dtypes.float32, 1)),
      (RaggedTensorSpec([32, 2], dtypes.float32, 0),
       RaggedTensorSpec([2], dtypes.float32, -1)),
      (RaggedTensorSpec([32, None, 4], dtypes.float32, 1, dtypes.int32),
       RaggedTensorSpec([None, 4], dtypes.float32, 0, dtypes.int32)),
  ])  # pyformat: disable
  def testUnbatch(self, spec, expected):
    unbatched = spec._unbatch()
    self.assertEqual(unbatched, expected)

  def testIsCompatibleWith(self):
    spec1 = RaggedTensorSpec([32, None, None], dtypes.float32, 2)
    spec2 = RaggedTensorSpec(None, dtypes.float32, 2)
    spec3 = RaggedTensorSpec(None, dtypes.int32, 1)
    spec4 = RaggedTensorSpec([None], dtypes.int32, 0)

    # spec2's unknown shape is compatible with spec1's partially known shape.
    self.assertTrue(spec1.is_compatible_with(spec2))
    # Each of these pairs differs in dtype and/or ragged_rank.
    incompatible_pairs = [(spec1, spec3), (spec1, spec4), (spec2, spec3),
                          (spec2, spec4), (spec3, spec4)]
    for lhs, rhs in incompatible_pairs:
      self.assertFalse(lhs.is_compatible_with(rhs))
    # A ragged_rank=0 spec is compatible with a dense Tensor.
    self.assertTrue(spec4.is_compatible_with(constant_op.constant([1, 2, 3])))
1930
1931
if __name__ == '__main__':
  # Entry point when this file is executed directly: hand control to the
  # googletest runner.
  googletest.main()
1934