# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test utilities for tf.data functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from tensorflow.python import tf2
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test


class DatasetTestBase(test.TestCase):
  """Base class for dataset tests."""

  @classmethod
  def setUpClass(cls):
    if tf2.enabled():
      dataset_ops.Dataset = dataset_ops.DatasetV2
    else:
      dataset_ops.Dataset = dataset_ops.DatasetV1

  def assertSparseValuesEqual(self, a, b):
    """Asserts that two SparseTensors/SparseTensorValues are equal."""
    self.assertAllEqual(a.indices, b.indices)
    self.assertAllEqual(a.values, b.values)
    self.assertAllEqual(a.dense_shape, b.dense_shape)

  def getNext(self, dataset, requires_initialization=False):
    """Returns a callable that returns the next element of the dataset.

    Example use:
    ```python
    # In both graph and eager modes
    dataset = ...
    get_next = self.getNext(dataset)
    result = self.evaluate(get_next())
    ```

    Args:
      dataset: A dataset whose elements will be returned.
      requires_initialization: Indicates that when the test is executed in
        graph mode, it should use an initializable iterator to iterate through
        the dataset (e.g. when it contains stateful nodes). Defaults to False.

    Returns:
      A callable that returns the next element of `dataset`.
    """
    if context.executing_eagerly():
      iterator = iter(dataset)
      return iterator._next_internal  # pylint: disable=protected-access
    else:
      if requires_initialization:
        iterator = dataset_ops.make_initializable_iterator(dataset)
        self.evaluate(iterator.initializer)
      else:
        iterator = dataset_ops.make_one_shot_iterator(dataset)
      get_next = iterator.get_next()
      return lambda: get_next

  def _compareOutputToExpected(self, result_values, expected_values,
                               assert_items_equal):
    if assert_items_equal:
      # TODO(shivaniagrawal): add support for nested elements containing sparse
      # tensors when needed.
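      # `assertItemsEqual` only checks that both sequences contain the same
      # elements, ignoring their order (it treats them as multisets).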
      self.assertItemsEqual(result_values, expected_values)
      return
    for i in range(len(result_values)):
      nest.assert_same_structure(result_values[i], expected_values[i])
      for result_value, expected_value in zip(
          nest.flatten(result_values[i]), nest.flatten(expected_values[i])):
        if sparse_tensor.is_sparse(result_value):
          self.assertSparseValuesEqual(result_value, expected_value)
        else:
          self.assertAllEqual(result_value, expected_value)

  def assertDatasetProduces(self,
                            dataset,
                            expected_output=None,
                            expected_shapes=None,
                            expected_error=None,
                            requires_initialization=False,
                            num_test_iterations=1,
                            assert_items_equal=False,
                            expected_error_iter=1):
    """Asserts that a dataset produces the expected output / error.

    Args:
      dataset: A dataset to check for the expected output / error.
      expected_output: A list of elements that the dataset is expected to
        produce.
      expected_shapes: A list of TensorShapes which is expected to match
        output_shapes of dataset.
      expected_error: A tuple `(type, predicate)` identifying the expected
        error `dataset` should raise. The `type` should match the expected
        exception type, while `predicate` should either be 1) a unary function
        that takes the raised exception as input and returns a boolean
        indicating success, or 2) a regular expression that is expected to
        partially match the error message.
      requires_initialization: Indicates that when the test is executed in
        graph mode, it should use an initializable iterator to iterate through
        the dataset (e.g. when it contains stateful nodes). Defaults to False.
      num_test_iterations: Number of times `dataset` will be iterated. Defaults
        to 1.
      assert_items_equal: Tests that the elements produced by the dataset are
        the same as those in `expected_output`, regardless of order.
      expected_error_iter: How many times to iterate before expecting an error,
        if an error is expected.
    """
    self.assertTrue(
        expected_error is not None or expected_output is not None,
        "Exactly one of expected_output or expected_error should be provided.")
    if expected_error:
      self.assertTrue(
          expected_output is None,
          "Exactly one of expected_output or expected_error should be provided."
      )
      with self.assertRaisesWithPredicateMatch(expected_error[0],
                                               expected_error[1]):
        get_next = self.getNext(
            dataset, requires_initialization=requires_initialization)
        for _ in range(expected_error_iter):
          self.evaluate(get_next())
      return
    if expected_shapes:
      self.assertEqual(expected_shapes,
                       dataset_ops.get_legacy_output_shapes(dataset))
    self.assertGreater(num_test_iterations, 0)
    for _ in range(num_test_iterations):
      get_next = self.getNext(
          dataset, requires_initialization=requires_initialization)
      result = []
      for _ in range(len(expected_output)):
        result.append(self.evaluate(get_next()))
      self._compareOutputToExpected(result, expected_output,
                                    assert_items_equal)
      # Once the dataset is exhausted, further calls should keep raising
      # OutOfRangeError.
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
      with self.assertRaises(errors.OutOfRangeError):
        self.evaluate(get_next())
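
  # A minimal sketch (hypothetical, not part of this module) of how a subclass
  # might call `assertDatasetProduces`; the test name and dataset below are
  # illustrative only:
  #
  #   def testRange(self):
  #     dataset = dataset_ops.Dataset.range(3)
  #     self.assertDatasetProduces(dataset, expected_output=[0, 1, 2])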

  def assertDatasetsEqual(self, dataset1, dataset2):
    """Checks that datasets are equal. Supports both graph and eager mode."""
    self.assertTrue(dataset_ops.get_structure(dataset1).is_compatible_with(
        dataset_ops.get_structure(dataset2)))
    self.assertTrue(dataset_ops.get_structure(dataset2).is_compatible_with(
        dataset_ops.get_structure(dataset1)))
    flattened_types = nest.flatten(
        dataset_ops.get_legacy_output_types(dataset1))

    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)
    while True:
      try:
        op1 = self.evaluate(next1())
      except errors.OutOfRangeError:
        with self.assertRaises(errors.OutOfRangeError):
          self.evaluate(next2())
        break
      op2 = self.evaluate(next2())

      op1 = nest.flatten(op1)
      op2 = nest.flatten(op2)
      assert len(op1) == len(op2)
      for i in range(len(op1)):
        if sparse_tensor.is_sparse(op1[i]):
          self.assertSparseValuesEqual(op1[i], op2[i])
        elif flattened_types[i] == dtypes.string:
          self.assertAllEqual(op1[i], op2[i])
        else:
          self.assertAllClose(op1[i], op2[i])

  def assertDatasetsRaiseSameError(self,
                                   dataset1,
                                   dataset2,
                                   exception_class,
                                   replacements=None):
    """Checks that datasets raise the same error on the first get_next call."""
    if replacements is None:
      replacements = []
    next1 = self.getNext(dataset1)
    next2 = self.getNext(dataset2)
    try:
      self.evaluate(next1())
      raise ValueError(
          "Expected dataset to raise an error of type %s, but it did not." %
          repr(exception_class))
    except exception_class as e:
      expected_message = e.message
      for old, new, count in replacements:
        expected_message = expected_message.replace(old, new, count)
      # Check that the first segment of the error messages is the same.
      with self.assertRaisesRegexp(exception_class,
                                   re.escape(expected_message)):
        self.evaluate(next2())

  def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
    """Returns a singleton dataset with the given structure."""
    if shape is None:
      shape = []
    if structure is None:
      return dataset_ops.Dataset.from_tensors(
          array_ops.zeros(shape, dtype=dtype))
    else:
      return dataset_ops.Dataset.zip(
          tuple([
              self.structuredDataset(substructure, shape, dtype)
              for substructure in structure
          ]))

  def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
    """Returns an element with the given structure."""
    if shape is None:
      shape = []
    if structure is None:
      return array_ops.zeros(shape, dtype=dtype)
    else:
      return tuple([
          self.structuredElement(substructure, shape, dtype)
          for substructure in structure
      ])
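

# A minimal sketch (hypothetical, not part of the TensorFlow test suite) of a
# test case built on `DatasetTestBase`; `structuredDataset(None)` yields a
# singleton dataset containing a scalar zero, so the matching expected output
# is `[0]`:
#
#   class ExampleDatasetTest(DatasetTestBase):
#
#     def testSingletonStructure(self):
#       dataset = self.structuredDataset(structure=None)
#       self.assertDatasetProduces(dataset, expected_output=[0])
#
#     def testDatasetsEqual(self):
#       self.assertDatasetsEqual(dataset_ops.Dataset.range(5),
#                                dataset_ops.Dataset.range(5))
#
#   if __name__ == "__main__":
#     test.main()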