• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Tests for RaggedTensor.from_tensor."""
16
17from absl.testing import parameterized
18
19import numpy as np
20
21from tensorflow.python.eager import context
22from tensorflow.python.framework import constant_op
23from tensorflow.python.framework import test_util
24from tensorflow.python.ops import array_ops
25from tensorflow.python.ops import math_ops
26from tensorflow.python.ops.ragged.ragged_tensor import RaggedTensor
27from tensorflow.python.platform import googletest
28
29
30@test_util.run_all_in_graph_and_eager_modes
31class RaggedTensorFromTensorOpTest(test_util.TensorFlowTestCase,
32                                   parameterized.TestCase):
33
34  def testDocStringExamples(self):
35    # The examples from RaggedTensor.from_tensor.__doc__.
36    dt = constant_op.constant([[5, 7, 0], [0, 3, 0], [6, 0, 0]])
37    self.assertAllEqual(
38        RaggedTensor.from_tensor(dt), [[5, 7, 0], [0, 3, 0], [6, 0, 0]])
39
40    self.assertAllEqual(
41        RaggedTensor.from_tensor(dt, lengths=[1, 0, 3]), [[5], [], [6, 0, 0]])
42
43    self.assertAllEqual(
44        RaggedTensor.from_tensor(dt, padding=0), [[5, 7], [0, 3], [6]])
45
46    dt_3d = constant_op.constant([[[5, 0], [7, 0], [0, 0]],
47                                  [[0, 0], [3, 0], [0, 0]],
48                                  [[6, 0], [0, 0], [0, 0]]])
49    self.assertAllEqual(
50        RaggedTensor.from_tensor(dt_3d, lengths=([2, 0, 3], [1, 1, 2, 0, 1])),
51        [[[5], [7]], [], [[6, 0], [], [0]]])
52
53  @parameterized.parameters(
54      # 2D test cases, no length or padding.
55      {
56          'tensor': [[]],
57          'expected': [[]],
58          'expected_shape': [1, 0],
59      },
60      {
61          'tensor': [[1]],
62          'expected': [[1]],
63          'expected_shape': [1, 1],
64      },
65      {
66          'tensor': [[1, 2]],
67          'expected': [[1, 2]],
68          'expected_shape': [1, 2],
69      },
70      {
71          'tensor': [[1], [2], [3]],
72          'expected': [[1], [2], [3]],
73          'expected_shape': [3, 1],
74      },
75      {
76          'tensor': [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
77          'expected': [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
78          'expected_shape': [3, 3],
79      },
80      # 3D test cases, no length or padding
81      {
82          'tensor': [[[]]],
83          'expected': [[[]]],
84          'expected_shape': [1, 1, 0],
85      },
86      {
87          'tensor': [[[]]],
88          'expected': [[[]]],
89          'ragged_rank': 1,
90          'expected_shape': [1, 1, 0],
91      },
92      {
93          'tensor': [[[1]]],
94          'expected': [[[1]]],
95          'expected_shape': [1, 1, 1],
96      },
97      {
98          'tensor': [[[1, 2]]],
99          'expected': [[[1, 2]]],
100          'expected_shape': [1, 1, 2],
101      },
102      {
103          'tensor': [[[1, 2], [3, 4]]],
104          'expected': [[[1, 2], [3, 4]]],
105          'expected_shape': [1, 2, 2],
106      },
107      {
108          'tensor': [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]],
109          'expected': [[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]],
110          'expected_shape': [4, 1, 2],
111      },
112      {
113          'tensor': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
114          'expected': [[[1], [2]], [[3], [4]], [[5], [6]], [[7], [8]]],
115          'expected_shape': [4, 2, 1],
116      },
117      # 2D test cases, with length
118      {
119          'tensor': [[1]],
120          'lengths': [1],
121          'expected': [[1]],
122          'expected_shape': [1, None],
123      },
124      {
125          'tensor': [[1]],
126          'lengths': [0],
127          'expected': [[]],
128          'expected_shape': [1, None],
129      },
130      {
131          'tensor': [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
132          'lengths': [0, 1, 2],
133          'expected': [[], [4], [7, 8]],
134          'expected_shape': [3, None],
135      },
136      {
137          'tensor': [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
138          'lengths': [0, 0, 0],
139          'expected': [[], [], []],
140          'expected_shape': [3, None],
141      },
142      {
143          'tensor': [[1, 2], [3, 4]],
144          'lengths': [2, 2],
145          'expected': [[1, 2], [3, 4]],
146          'expected_shape': [2, None],
147      },
148      {
149          'tensor': [[1, 2], [3, 4]],
150          'lengths': [7, 8],  # lengths > ncols: truncated to ncols
151          'expected': [[1, 2], [3, 4]],
152          'expected_shape': [2, None],
153      },
154      {
155          'tensor': [[1, 2], [3, 4]],
156          'lengths': [-2, -1],  # lengths < 0: treated as zero
157          'expected': [[], []],
158          'expected_shape': [2, None],
159      },
160      # 3D test cases, with length
161      {
162          'tensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
163          'lengths': [0, 0],
164          'expected': [[], []],
165          'expected_shape': [2, None, 2],
166      },
167      {
168          'tensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
169          'lengths': [1, 2],
170          'expected': [[[1, 2]], [[5, 6], [7, 8]]],
171          'expected_shape': [2, None, 2],
172      },
173      {
174          'tensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
175          'lengths': [2, 2],
176          'expected': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
177          'expected_shape': [2, None, 2],
178      },
179      # 2D test cases, with padding
180      {
181          'tensor': [[1]],
182          'padding': 0,
183          'expected': [[1]],
184          'expected_shape': [1, None],
185      },
186      {
187          'tensor': [[0]],
188          'padding': 0,
189          'expected': [[]],
190          'expected_shape': [1, None],
191      },
192      {
193          'tensor': [[0, 1]],
194          'padding': 0,
195          'expected': [[0, 1]],
196          'expected_shape': [1, None],
197      },
198      {
199          'tensor': [[1, 0]],
200          'padding': 0,
201          'expected': [[1]],
202          'expected_shape': [1, None],
203      },
204      {
205          'tensor': [[1, 0, 1, 0, 0, 1, 0, 0]],
206          'padding': 0,
207          'expected': [[1, 0, 1, 0, 0, 1]],
208          'expected_shape': [1, None],
209      },
210      {
211          'tensor': [[3, 7, 0, 0], [2, 0, 0, 0], [5, 0, 0, 0]],
212          'padding': 0,
213          'expected': [[3, 7], [2], [5]],
214          'expected_shape': [3, None],
215      },
216      # 3D test cases, with padding
217      {
218          'tensor': [[[1]]],
219          'padding': [0],
220          'expected': [[[1]]],
221          'expected_shape': [1, None, 1],
222      },
223      {
224          'tensor': [[[0]]],
225          'padding': [0],
226          'expected': [[]],
227          'expected_shape': [1, None, 1],
228      },
229      {
230          'tensor': [[[0, 0], [1, 2]], [[3, 4], [0, 0]]],
231          'padding': [0, 0],
232          'expected': [[[0, 0], [1, 2]], [[3, 4]]],
233          'expected_shape': [2, None, 2],
234      },
235      # 4D test cases, with padding
236      {
237          'tensor': [
238              [[[1, 2], [3, 4]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]],
239              [[[0, 0], [0, 0]], [[5, 6], [7, 8]], [[0, 0], [0, 0]]],
240              [[[0, 0], [0, 0]], [[0, 0], [0, 0]], [[0, 0], [0, 0]]]
241          ],
242          'padding': [[0, 0], [0, 0]],
243          'expected': [
244              [[[1, 2], [3, 4]]],
245              [[[0, 0], [0, 0]], [[5, 6], [7, 8]]],
246              []
247          ],
248          'expected_shape': [3, None, 2, 2],
249      },
250      # 3D test cases, with ragged_rank=2.
251      {
252          'tensor': [[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
253          'ragged_rank': 2,
254          'expected': [[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
255          'expected_shape': [2, 2, 2],
256      },
257      {
258          'tensor': [[[1, 2], [3, 4]], [[5, 6], [7, 8]]],
259          'ragged_rank': 2,
260          'lengths': [2, 0, 2, 1],
261          'expected': [[[1, 2], []], [[5, 6], [7]]],
262          'expected_shape': [2, 2, None],
263      },
264      {
265          'tensor': [[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
266          'ragged_rank': 2,
267          'padding': 0,
268          'expected': [[[1], [2, 3]], [[], [4]]],
269          'expected_shape': [2, 2, None],
270      },
271      # 4D test cases, with ragged_rank>1
272      {
273          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
274                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
275          'ragged_rank': 2,
276          'expected': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
277                       [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
278          'expected_shape': [2, 2, 2, 2],
279      },
280      {
281          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
282                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
283          'ragged_rank': 3,
284          'expected': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
285                       [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
286          'expected_shape': [2, 2, 2, 2],
287      },
288      {
289          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
290                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
291          'ragged_rank': 2,
292          'padding': [0, 0],
293          'expected': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
294                       [[[5, 6], [7, 0]], [[0, 8]]]],
295          'expected_shape': [2, 2, None, 2],
296      },
297      {
298          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
299                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
300          'lengths': ([2, 2], [1, 2, 2, 1]),
301          'expected': [[[[1, 0]], [[0, 0], [4, 0]]],
302                       [[[5, 6], [7, 0]], [[0, 8]]]],
303          'ragged_rank': 2,
304          'use_ragged_rank': False,  # lengths contains nested_row_lengths.
305          'expected_shape': [2, None, None, 2],
306      },
307      {
308          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
309                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
310          'lengths': [[2, 2], [1, 2, 2, 1]],
311          'expected': [[[[1, 0]], [[0, 0], [4, 0]]],
312                       [[[5, 6], [7, 0]], [[0, 8]]]],
313          'ragged_rank': 2,
314          'use_ragged_rank': False,  # lengths contains nested_row_lengths.
315          'expected_shape': [2, None, None, 2],
316      },
317      {
318          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
319                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
320          'ragged_rank': 3,
321          'padding': 0,
322          'expected': [[[[1], [2, 3]], [[], [4]]],
323                       [[[5, 6], [7]], [[0, 8], []]]],
324          'expected_shape': [2, 2, 2, None],
325      },
326      {
327          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
328                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
329          'lengths': ([2, 2], [2, 2, 2, 2], [1, 2, 0, 1, 2, 1, 2, 0]),
330          'expected': [[[[1], [2, 3]], [[], [4]]],
331                       [[[5, 6], [7]], [[0, 8], []]]],
332          'ragged_rank': 3,
333          'use_ragged_rank': False,  # lengths contains nested_row_lengths.
334          'expected_shape': [2, None, None, None],
335      },
336      {
337          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
338                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
339          'lengths': [[2, 2], [2, 2, 2, 2], [1, 2, 0, 1, 2, 1, 2, 0]],
340          'expected': [[[[1], [2, 3]], [[], [4]]],
341                       [[[5, 6], [7]], [[0, 8], []]]],
342          'ragged_rank': 3,
343          'use_ragged_rank': False,  # lengths contains nested_row_lengths.
344          'expected_shape': [2, None, None, None],
345      },
346  )  # pyformat: disable
347  def testRaggedFromTensor(self,
348                           tensor,
349                           expected,
350                           lengths=None,
351                           padding=None,
352                           ragged_rank=1,
353                           use_ragged_rank=True,
354                           expected_shape=None):
355    dt = constant_op.constant(tensor)
356    if use_ragged_rank:
357      rt = RaggedTensor.from_tensor(dt, lengths, padding, ragged_rank)
358    else:
359      rt = RaggedTensor.from_tensor(dt, lengths, padding)
360    self.assertEqual(type(rt), RaggedTensor)
361    self.assertEqual(rt.ragged_rank, ragged_rank)
362    self.assertTrue(
363        dt.shape.is_compatible_with(rt.shape),
364        '%s is incompatible with %s' % (dt.shape, rt.shape))
365    if expected_shape is not None:
366      self.assertEqual(rt.shape.as_list(), expected_shape)
367    self.assertAllEqual(rt, expected)
368    self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
369        rt.flat_values, rt.nested_row_splits, validate=True))
370
371  def testHighDimensions(self):
372    # Use distinct prime numbers for all dimension shapes in this test, so
373    # we can see any errors that are caused by mixing up dimension sizes.
374    dt = array_ops.reshape(
375        math_ops.range(3 * 5 * 7 * 11 * 13 * 17), [3, 5, 7, 11, 13, 17])
376    for ragged_rank in range(1, 4):
377      rt = RaggedTensor.from_tensor(dt, ragged_rank=ragged_rank)
378      self.assertEqual(type(rt), RaggedTensor)
379      self.assertEqual(rt.ragged_rank, ragged_rank)
380      self.assertTrue(
381          dt.shape.is_compatible_with(rt.shape),
382          '%s is incompatible with %s' % (dt.shape, rt.shape))
383      self.assertAllEqual(rt, self.evaluate(dt).tolist())
384      self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
385          rt.flat_values, rt.nested_row_splits, validate=True))
386
387  @parameterized.parameters(
388      # With no padding or lengths
389      {
390          'dt_shape': [0, 0],
391          'expected': []
392      },
393      {
394          'dt_shape': [0, 3],
395          'expected': []
396      },
397      {
398          'dt_shape': [3, 0],
399          'expected': [[], [], []]
400      },
401      {
402          'dt_shape': [0, 2, 3],
403          'expected': []
404      },
405      {
406          'dt_shape': [1, 0, 0],
407          'expected': [[]]
408      },
409      {
410          'dt_shape': [2, 0, 3],
411          'expected': [[], []]
412      },
413      {
414          'dt_shape': [2, 3, 0],
415          'expected': [[[], [], []], [[], [], []]]
416      },
417      {
418          'dt_shape': [2, 3, 0, 1],
419          'expected': [[[], [], []], [[], [], []]]
420      },
421      {
422          'dt_shape': [2, 3, 1, 0],
423          'expected': [[[[]], [[]], [[]]], [[[]], [[]], [[]]]]
424      },
425      # With padding
426      {
427          'dt_shape': [0, 0],
428          'padding': 0,
429          'expected': []
430      },
431      {
432          'dt_shape': [0, 3],
433          'padding': 0,
434          'expected': []
435      },
436      {
437          'dt_shape': [3, 0],
438          'padding': 0,
439          'expected': [[], [], []]
440      },
441      {
442          'dt_shape': [0, 2, 3],
443          'padding': [0, 0, 0],
444          'expected': []
445      },
446      {
447          'dt_shape': [2, 0, 3],
448          'padding': [0, 0, 0],
449          'expected': [[], []]
450      },
451      {
452          'dt_shape': [2, 3, 0],
453          'padding': [],
454          'expected': [[], []]
455      },
456      # With lengths
457      {
458          'dt_shape': [0, 0],
459          'lengths': [],
460          'expected': []
461      },
462      {
463          'dt_shape': [0, 3],
464          'lengths': [],
465          'expected': []
466      },
467      {
468          'dt_shape': [3, 0],
469          'lengths': [0, 0, 0],
470          'expected': [[], [], []]
471      },
472      {
473          'dt_shape': [3, 0],
474          'lengths': [2, 3, 4],  # lengths > ncols: truncated to ncols
475          'expected': [[], [], []]
476      },
477      {
478          'dt_shape': [0, 2, 3],
479          'lengths': [],
480          'expected': []
481      },
482      {
483          'dt_shape': [2, 0, 3],
484          'lengths': [0, 0],
485          'expected': [[], []]
486      },
487      {
488          'dt_shape': [2, 3, 0],
489          'lengths': [0, 0],
490          'expected': [[], []]
491      },
492  )
493  def testEmpty(self, dt_shape, expected, lengths=None, padding=None):
494    dt = array_ops.zeros(dt_shape)
495    for ragged_rank in range(1, len(dt_shape) - 1):
496      rt = RaggedTensor.from_tensor(dt, lengths, padding, ragged_rank)
497      self.assertEqual(type(rt), RaggedTensor)
498      self.assertEqual(rt.ragged_rank, ragged_rank)
499      self.assertTrue(dt.shape.is_compatible_with(rt.shape))
500      self.assertAllEqual(rt, expected)
501      self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
502          rt.flat_values, rt.nested_row_splits, validate=True))
503
504  @parameterized.named_parameters([
505      {
506          'testcase_name': '2D_UnknownRank',
507          'tensor': [[1, 2], [3, 4]],
508          'tensor_shape': None,
509      },
510      {
511          'testcase_name': '2D_Shape_None_None',
512          'tensor': [[1, 2], [3, 4]],
513          'tensor_shape': [None, None],
514      },
515      {
516          'testcase_name': '2D_Shape_2_None',
517          'tensor': [[1, 2], [3, 4]],
518          'tensor_shape': [2, None],
519      },
520      {
521          'testcase_name': '2D_Shape_None_2',
522          'tensor': [[1, 2], [3, 4]],
523          'tensor_shape': [None, 2],
524      },
525      {
526          'testcase_name': '4D_UnknownRank',
527          'tensor': np.ones([4, 3, 2, 1]),
528          'tensor_shape': None,
529      },
530      {
531          'testcase_name': '4D_Shape_None_None_None_None',
532          'tensor': np.ones([4, 3, 2, 1]),
533          'tensor_shape': [None, None, None, None],
534      },
535      {
536          'tensor': np.ones([4, 3, 2, 1]),
537          'tensor_shape': [4, None, None, 1],
538          'testcase_name': '4D_Shape_4_None_None_1',
539      },
540  ])
541  def testPartialShapes(self, tensor, tensor_shape, shape=None,
542                        expected=None):
543    if expected is None:
544      expected = tensor
545
546    if context.executing_eagerly():
547      return  # static shapes are always fully defined in eager mode.
548
549    dt = constant_op.constant(tensor)
550    for ragged_rank in range(1, len(dt.shape) - 1):
551      dt_placeholder = array_ops.placeholder_with_default(tensor, tensor_shape)
552      rt = RaggedTensor.from_tensor(dt_placeholder, ragged_rank=ragged_rank)
553      self.assertIsInstance(rt, RaggedTensor)
554      self.assertEqual(rt.ragged_rank, ragged_rank)
555      self.assertTrue(
556          dt.shape.is_compatible_with(rt.shape),
557          '%s is incompatible with %s' % (dt.shape, rt.shape))
558      if shape is not None:
559        self.assertEqual(rt.shape.as_list(), shape)
560      self.assertAllEqual(rt, expected.tolist())
561      self.assertAllEqual(rt, RaggedTensor.from_nested_row_splits(
562          rt.flat_values, rt.nested_row_splits, validate=True))
563
564  @parameterized.parameters(
565      {
566          'tensor': [[1]],
567          'lengths': [0],
568          'padding':
569              0,
570          'error': (ValueError,
571                    'Specify argument `lengths` or `padding`, but not both.')
572      },
573      {
574          'tensor': [[1]],
575          'lengths': [0.5],
576          'error': (
577              TypeError,
578              r'Argument `tensor` \(name\: lengths\) must be of type integer.*')
579      },
580      {
581          'tensor': [[1, 2, 3]],
582          'lengths': [[1], [1]],
583          'error': (ValueError, r'Shape \(1, 3\) must have rank at least 3')
584      },
585      {
586          'tensor': [[1]],
587          'padding': 'a',
588          'error': (TypeError, '.*')
589      },
590      {
591          'tensor': [[1]],
592          'padding': [1],
593          'error': (ValueError, r'Shapes \(1,\) and \(\) are incompatible')
594      },
595      {
596          'tensor': [[[1]]],
597          'padding': 1,
598          'error': (ValueError, r'Shapes \(\) and \(1,\) are incompatible')
599      },
600      {
601          'tensor': [[1]],
602          'ragged_rank':
603              'bad',
604          'error': (TypeError,
605                    r'Argument `ragged_rank` must be an int. Received bad.')
606      },
607      {
608          'tensor': [[1]],
609          'ragged_rank':
610              0,
611          'error':
612              (ValueError,
613               r'Argument `ragged_rank` must be greater than 0. Received 0.')
614      },
615      {
616          'tensor': [[1]],
617          'ragged_rank':
618              -1,
619          'error':
620              (ValueError,
621               r'Argument `ragged_rank` must be greater than 0. Received -1.')
622      },
623      {
624          'tensor': [[[[1, 0], [2, 3]], [[0, 0], [4, 0]]],
625                     [[[5, 6], [7, 0]], [[0, 8], [0, 0]]]],
626          'lengths': ([2, 2], [2, 2, 2, 2]),
627          'ragged_rank':
628              3,
629          'error':
630              (ValueError,
631               r'If Argument `lengths` is a tuple of row_lengths, argument '
632               r'`ragged_rank` must be len\(lengths\): 2. Received '
633               r'ragged_rank: 3.')
634      },
635  )
636  def testErrors(self,
637                 tensor,
638                 lengths=None,
639                 padding=None,
640                 ragged_rank=1,
641                 error=None):
642    dt = constant_op.constant(tensor)
643    self.assertRaisesRegex(error[0], error[1], RaggedTensor.from_tensor, dt,
644                           lengths, padding, ragged_rank)
645
646
# When this file is executed directly (not imported), hand control to
# `googletest.main()`.
if __name__ == '__main__':
  googletest.main()
649