• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15
16# pylint: disable=invalid-name
17"""Test utils for tensorflow RaggedTensors."""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import numpy as np
24
25from tensorflow.python.framework import ops
26from tensorflow.python.framework import test_util
27from tensorflow.python.ops.ragged import ragged_tensor
28from tensorflow.python.ops.ragged import ragged_tensor_value
29
30
class RaggedTensorTestCase(test_util.TensorFlowTestCase):
  """Base class for RaggedTensor test cases.

  Provides assertions that compare potentially ragged values (RaggedTensors,
  RaggedTensorValues, Tensors, numpy arrays and nested Python lists) by first
  converting both sides to nested Python lists.
  """

  def _GetPyList(self, a):
    """Converts a to a nested python list.

    Args:
      a: A `RaggedTensor`, `Tensor`, `np.ndarray`, `RaggedTensorValue`, or any
        other value, which is converted via `np.array(a).tolist()` (with an
        element-wise fallback for ragged nested lists/tuples).

    Returns:
      A nested Python list, or the evaluated value itself when evaluating a
      `Tensor` does not produce an `np.ndarray`.
    """
    if isinstance(a, ragged_tensor.RaggedTensor):
      return self.evaluate(a).to_list()
    elif isinstance(a, ops.Tensor):
      a = self.evaluate(a)
      return a.tolist() if isinstance(a, np.ndarray) else a
    elif isinstance(a, np.ndarray):
      return a.tolist()
    elif isinstance(a, ragged_tensor_value.RaggedTensorValue):
      return a.to_list()
    try:
      return np.array(a).tolist()
    except ValueError:
      # NumPy >= 1.24 raises ValueError when asked to build an array from a
      # ragged nested sequence (older versions built an object array instead).
      # Convert each element on its own so ragged expected values such as
      # [[1, 2], [3]] keep working; re-raise for anything that is not a
      # list/tuple, since the fallback only makes sense for sequences.
      if not isinstance(a, (list, tuple)):
        raise
      return [self._GetPyList(item) for item in a]

  def _assertRaggedRankEqual(self, a, b):
    """Asserts that a and b have the same ragged_rank (0 for non-ragged).

    The check is skipped when either side is a Python list or tuple, because
    a plain Python sequence has no `ragged_rank` to compare.
    """
    if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
      return
    a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0
    b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0
    self.assertEqual(a_ragged_rank, b_ragged_rank)

  def assertRaggedEqual(self, a, b):
    """Asserts that two potentially ragged tensors are equal.

    Both values are converted to nested Python lists and compared exactly.
    Unless either side is a Python list/tuple, their ragged ranks must match
    as well.
    """
    a_list = self._GetPyList(a)
    b_list = self._GetPyList(b)
    self.assertEqual(a_list, b_list)
    self._assertRaggedRankEqual(a, b)

  def assertRaggedAlmostEqual(self, a, b, places=7):
    """Asserts that two potentially ragged tensors are almost equal.

    Both values are converted to nested Python lists and compared element by
    element with `assertNestedListAlmostEqual`. Unless either side is a Python
    list/tuple, their ragged ranks must match as well.

    Args:
      a: The first value to compare.
      b: The second value to compare.
      places: Number of decimal places used for each element comparison.
    """
    a_list = self._GetPyList(a)
    b_list = self._GetPyList(b)
    self.assertNestedListAlmostEqual(a_list, b_list, places, context='value')
    self._assertRaggedRankEqual(a, b)

  def assertNestedListAlmostEqual(self, a, b, places=7, context='value'):
    """Asserts that two nested lists are element-wise almost equal.

    `a` and `b` must have identical Python types at every level. Lists/tuples
    must have equal lengths and are compared recursively; leaves are compared
    with `assertAlmostEqual`.

    Args:
      a: A nested list/tuple (typically of numbers), or a single leaf value.
      b: The value to compare against `a`.
      places: Number of decimal places passed to `assertAlmostEqual`.
      context: Path to the current position (e.g. 'value[0][2]'), used in
        failure messages.
    """
    self.assertEqual(type(a), type(b), 'Type differs for %s' % context)
    if isinstance(a, (list, tuple)):
      self.assertLen(a, len(b), 'Length differs for %s' % context)
      for i, (a_item, b_item) in enumerate(zip(a, b)):
        self.assertNestedListAlmostEqual(a_item, b_item, places,
                                         '%s[%s]' % (context, i))
    else:
      self.assertAlmostEqual(
          a, b, places,
          '%s != %s within %s places at %s' % (a, b, places, context))

  def eval_to_list(self, tensor):
    """Evaluates `tensor` and converts the result to Python values.

    Ragged results are converted with `to_list()` and numpy arrays with
    `tolist()`. Any other evaluated value is returned unchanged.
    """
    value = self.evaluate(tensor)
    if ragged_tensor.is_ragged(value):
      return value.to_list()
    elif isinstance(value, np.ndarray):
      return value.tolist()
    else:
      return value

  def _eval_tensor(self, tensor):
    """Evaluates `tensor`, recursing into the components of ragged tensors.

    Overrides `TensorFlowTestCase._eval_tensor`. A ragged tensor evaluates to a
    `RaggedTensorValue` built from its evaluated `values` and `row_splits`,
    and all other inputs are delegated to the base-class implementation.
    """
    if ragged_tensor.is_ragged(tensor):
      return ragged_tensor_value.RaggedTensorValue(
          self._eval_tensor(tensor.values),
          self._eval_tensor(tensor.row_splits))
    else:
      return test_util.TensorFlowTestCase._eval_tensor(self, tensor)
97