1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15 16# pylint: disable=invalid-name 17"""Test utils for tensorflow RaggedTensors.""" 18 19from __future__ import absolute_import 20from __future__ import division 21from __future__ import print_function 22 23import numpy as np 24 25from tensorflow.python.framework import ops 26from tensorflow.python.framework import test_util 27from tensorflow.python.ops.ragged import ragged_tensor 28from tensorflow.python.ops.ragged import ragged_tensor_value 29 30 31class RaggedTensorTestCase(test_util.TensorFlowTestCase): 32 """Base class for RaggedTensor test cases.""" 33 34 def _GetPyList(self, a): 35 """Converts a to a nested python list.""" 36 if isinstance(a, ragged_tensor.RaggedTensor): 37 return self.evaluate(a).to_list() 38 elif isinstance(a, ops.Tensor): 39 a = self.evaluate(a) 40 return a.tolist() if isinstance(a, np.ndarray) else a 41 elif isinstance(a, np.ndarray): 42 return a.tolist() 43 elif isinstance(a, ragged_tensor_value.RaggedTensorValue): 44 return a.to_list() 45 else: 46 return np.array(a).tolist() 47 48 def assertRaggedEqual(self, a, b): 49 """Asserts that two potentially ragged tensors are equal.""" 50 a_list = self._GetPyList(a) 51 b_list = self._GetPyList(b) 52 self.assertEqual(a_list, b_list) 53 54 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 55 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 56 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 57 self.assertEqual(a_ragged_rank, b_ragged_rank) 58 59 def assertRaggedAlmostEqual(self, a, b, places=7): 60 a_list = self._GetPyList(a) 61 b_list = self._GetPyList(b) 62 self.assertNestedListAlmostEqual(a_list, b_list, places, context='value') 63 64 if not (isinstance(a, (list, tuple)) or isinstance(b, (list, tuple))): 65 a_ragged_rank = a.ragged_rank if ragged_tensor.is_ragged(a) else 0 66 b_ragged_rank = b.ragged_rank if ragged_tensor.is_ragged(b) else 0 67 self.assertEqual(a_ragged_rank, b_ragged_rank) 68 69 def assertNestedListAlmostEqual(self, a, b, places=7, context='value'): 70 self.assertEqual(type(a), type(b)) 71 if isinstance(a, (list, tuple)): 72 self.assertLen(a, len(b), 'Length differs for %s' % context) 73 for i in range(len(a)): 74 self.assertNestedListAlmostEqual(a[i], b[i], places, 75 '%s[%s]' % (context, i)) 76 else: 77 self.assertAlmostEqual( 78 a, b, places, 79 '%s != %s within %s places at %s' % (a, b, places, context)) 80 81 def eval_to_list(self, tensor): 82 value = self.evaluate(tensor) 83 if ragged_tensor.is_ragged(value): 84 return value.to_list() 85 elif isinstance(value, np.ndarray): 86 return value.tolist() 87 else: 88 return value 89 90 def _eval_tensor(self, tensor): 91 if ragged_tensor.is_ragged(tensor): 92 return ragged_tensor_value.RaggedTensorValue( 93 self._eval_tensor(tensor.values), 94 self._eval_tensor(tensor.row_splits)) 95 else: 96 return test_util.TensorFlowTestCase._eval_tensor(self, tensor) 97