1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Test utilities for tf.data functionality.""" 16from __future__ import absolute_import 17from __future__ import division 18from __future__ import print_function 19 20import re 21 22from tensorflow.python.data.ops import dataset_ops 23from tensorflow.python.data.util import nest 24from tensorflow.python.data.util import structure 25from tensorflow.python.eager import context 26from tensorflow.python.framework import combinations 27from tensorflow.python.framework import dtypes 28from tensorflow.python.framework import errors 29from tensorflow.python.framework import ops 30from tensorflow.python.framework import sparse_tensor 31from tensorflow.python.ops import array_ops 32from tensorflow.python.ops import gen_dataset_ops 33from tensorflow.python.ops import gen_experimental_dataset_ops 34from tensorflow.python.ops import tensor_array_ops 35from tensorflow.python.ops.ragged import ragged_tensor 36from tensorflow.python.platform import test 37 38 39def default_test_combinations(): 40 """Returns the default test combinations for tf.data tests.""" 41 return combinations.combine(tf_api_version=[1, 2], mode=["eager", "graph"]) 42 43 44def eager_only_combinations(): 45 """Returns the default test combinations for eager mode only tf.data tests.""" 46 return combinations.combine(tf_api_version=[1, 2], mode="eager") 47 48 49def graph_only_combinations(): 50 """Returns the default test combinations for graph mode only tf.data tests.""" 51 return combinations.combine(tf_api_version=[1, 2], mode="graph") 52 53 54def v1_only_combinations(): 55 """Returns the default test combinations for v1 only tf.data tests.""" 56 return combinations.combine(tf_api_version=1, mode=["eager", "graph"]) 57 58 59def v2_only_combinations(): 60 """Returns the default test combinations for v2 only tf.data tests.""" 61 return combinations.combine(tf_api_version=2, mode=["eager", "graph"]) 62 63 64def v2_eager_only_combinations(): 65 """Returns the default test combinations for v2 eager only tf.data tests.""" 66 return combinations.combine(tf_api_version=2, mode="eager") 67 68 69class DatasetTestBase(test.TestCase): 70 """Base class for dataset tests.""" 71 72 def assert_op_cancelled(self, op): 73 with self.assertRaises(errors.CancelledError): 74 self.evaluate(op) 75 76 def assertValuesEqual(self, expected, actual): 77 """Asserts that two values are equal.""" 78 if isinstance(expected, dict): 79 self.assertItemsEqual(list(expected.keys()), list(actual.keys())) 80 for k in expected.keys(): 81 self.assertValuesEqual(expected[k], actual[k]) 82 elif sparse_tensor.is_sparse(expected): 83 self.assertAllEqual(expected.indices, actual.indices) 84 self.assertAllEqual(expected.values, actual.values) 85 self.assertAllEqual(expected.dense_shape, actual.dense_shape) 86 else: 87 self.assertAllEqual(expected, actual) 88 89 def getNext(self, dataset, requires_initialization=False, shared_name=None): 90 """Returns a callable that returns the next element of the dataset. 91 92 Example use: 93 ```python 94 # In both graph and eager modes 95 dataset = ... 96 get_next = self.getNext(dataset) 97 result = self.evaluate(get_next()) 98 ``` 99 100 Args: 101 dataset: A dataset whose elements will be returned. 102 requires_initialization: Indicates that when the test is executed in graph 103 mode, it should use an initializable iterator to iterate through the 104 dataset (e.g. when it contains stateful nodes). Defaults to False. 105 shared_name: (Optional.) If non-empty, the returned iterator will be 106 shared under the given name across multiple sessions that share the same 107 devices (e.g. when using a remote server). 108 Returns: 109 A callable that returns the next element of `dataset`. Any `TensorArray` 110 objects `dataset` outputs are stacked. 111 """ 112 def ta_wrapper(gn): 113 def _wrapper(): 114 r = gn() 115 if isinstance(r, tensor_array_ops.TensorArray): 116 return r.stack() 117 else: 118 return r 119 return _wrapper 120 121 # Create an anonymous iterator if we are in eager-mode or are graph inside 122 # of a tf.function. 123 if context.executing_eagerly() or ops.inside_function(): 124 iterator = iter(dataset) 125 return ta_wrapper(iterator._next_internal) # pylint: disable=protected-access 126 else: 127 if requires_initialization: 128 iterator = dataset_ops.make_initializable_iterator(dataset, shared_name) 129 self.evaluate(iterator.initializer) 130 else: 131 iterator = dataset_ops.make_one_shot_iterator(dataset) 132 get_next = iterator.get_next() 133 return ta_wrapper(lambda: get_next) 134 135 def _compareOutputToExpected(self, result_values, expected_values, 136 assert_items_equal): 137 if assert_items_equal: 138 # TODO(shivaniagrawal): add support for nested elements containing sparse 139 # tensors when needed. 140 self.assertItemsEqual(result_values, expected_values) 141 return 142 for i in range(len(result_values)): 143 nest.assert_same_structure(result_values[i], expected_values[i]) 144 for result_value, expected_value in zip( 145 nest.flatten(result_values[i]), nest.flatten(expected_values[i])): 146 self.assertValuesEqual(expected_value, result_value) 147 148 def getDatasetOutput(self, dataset, requires_initialization=False): 149 get_next = self.getNext( 150 dataset, requires_initialization=requires_initialization) 151 results = [] 152 while True: 153 try: 154 results.append(self.evaluate(get_next())) 155 except errors.OutOfRangeError: 156 break 157 return results 158 159 def assertDatasetProduces(self, 160 dataset, 161 expected_output=None, 162 expected_shapes=None, 163 expected_error=None, 164 requires_initialization=False, 165 num_test_iterations=1, 166 assert_items_equal=False, 167 expected_error_iter=1): 168 """Asserts that a dataset produces the expected output / error. 169 170 Args: 171 dataset: A dataset to check for the expected output / error. 172 expected_output: A list of elements that the dataset is expected to 173 produce. 174 expected_shapes: A list of TensorShapes which is expected to match 175 output_shapes of dataset. 176 expected_error: A tuple `(type, predicate)` identifying the expected error 177 `dataset` should raise. The `type` should match the expected exception 178 type, while `predicate` should either be 1) a unary function that inputs 179 the raised exception and returns a boolean indicator of success or 2) a 180 regular expression that is expected to match the error message 181 partially. 182 requires_initialization: Indicates that when the test is executed in graph 183 mode, it should use an initializable iterator to iterate through the 184 dataset (e.g. when it contains stateful nodes). Defaults to False. 185 num_test_iterations: Number of times `dataset` will be iterated. Defaults 186 to 1. 187 assert_items_equal: Tests expected_output has (only) the same elements 188 regardless of order. 189 expected_error_iter: How many times to iterate before expecting an error, 190 if an error is expected. 191 """ 192 self.assertTrue( 193 expected_error is not None or expected_output is not None, 194 "Exactly one of expected_output or expected error should be provided.") 195 if expected_error: 196 self.assertTrue( 197 expected_output is None, 198 "Exactly one of expected_output or expected error should be provided." 199 ) 200 with self.assertRaisesWithPredicateMatch(expected_error[0], 201 expected_error[1]): 202 get_next = self.getNext( 203 dataset, requires_initialization=requires_initialization) 204 for _ in range(expected_error_iter): 205 self.evaluate(get_next()) 206 return 207 if expected_shapes: 208 self.assertEqual(expected_shapes, 209 dataset_ops.get_legacy_output_shapes(dataset)) 210 self.assertGreater(num_test_iterations, 0) 211 for _ in range(num_test_iterations): 212 get_next = self.getNext( 213 dataset, requires_initialization=requires_initialization) 214 result = [] 215 for _ in range(len(expected_output)): 216 result.append(self.evaluate(get_next())) 217 self._compareOutputToExpected(result, expected_output, assert_items_equal) 218 with self.assertRaises(errors.OutOfRangeError): 219 self.evaluate(get_next()) 220 with self.assertRaises(errors.OutOfRangeError): 221 self.evaluate(get_next()) 222 223 def assertDatasetsEqual(self, dataset1, dataset2): 224 """Checks that datasets are equal. Supports both graph and eager mode.""" 225 self.assertTrue( 226 structure.are_compatible( 227 dataset_ops.get_structure(dataset1), 228 dataset_ops.get_structure(dataset2))) 229 230 flattened_types = nest.flatten( 231 dataset_ops.get_legacy_output_types(dataset1)) 232 233 next1 = self.getNext(dataset1) 234 next2 = self.getNext(dataset2) 235 236 while True: 237 try: 238 op1 = self.evaluate(next1()) 239 except errors.OutOfRangeError: 240 with self.assertRaises(errors.OutOfRangeError): 241 self.evaluate(next2()) 242 break 243 op2 = self.evaluate(next2()) 244 245 op1 = nest.flatten(op1) 246 op2 = nest.flatten(op2) 247 assert len(op1) == len(op2) 248 for i in range(len(op1)): 249 if sparse_tensor.is_sparse(op1[i]) or ragged_tensor.is_ragged(op1[i]): 250 self.assertValuesEqual(op1[i], op2[i]) 251 elif flattened_types[i] == dtypes.string: 252 self.assertAllEqual(op1[i], op2[i]) 253 else: 254 self.assertAllClose(op1[i], op2[i]) 255 256 def assertDatasetsRaiseSameError(self, 257 dataset1, 258 dataset2, 259 exception_class, 260 replacements=None): 261 """Checks that datasets raise the same error on the first get_next call.""" 262 if replacements is None: 263 replacements = [] 264 next1 = self.getNext(dataset1) 265 next2 = self.getNext(dataset2) 266 try: 267 self.evaluate(next1()) 268 raise ValueError( 269 "Expected dataset to raise an error of type %s, but it did not." % 270 repr(exception_class)) 271 except exception_class as e: 272 expected_message = e.message 273 for old, new, count in replacements: 274 expected_message = expected_message.replace(old, new, count) 275 # Check that the first segment of the error messages are the same. 276 with self.assertRaisesRegexp(exception_class, 277 re.escape(expected_message)): 278 self.evaluate(next2()) 279 280 def structuredDataset(self, dataset_structure, shape=None, 281 dtype=dtypes.int64): 282 """Returns a singleton dataset with the given structure.""" 283 if shape is None: 284 shape = [] 285 if dataset_structure is None: 286 return dataset_ops.Dataset.from_tensors( 287 array_ops.zeros(shape, dtype=dtype)) 288 else: 289 return dataset_ops.Dataset.zip( 290 tuple([ 291 self.structuredDataset(substructure, shape, dtype) 292 for substructure in dataset_structure 293 ])) 294 295 def graphRoundTrip(self, dataset, allow_stateful=False): 296 """Converts a dataset to a graph and back.""" 297 graph = gen_dataset_ops.dataset_to_graph( 298 dataset._variant_tensor, allow_stateful=allow_stateful) # pylint: disable=protected-access 299 return dataset_ops.from_variant( 300 gen_experimental_dataset_ops.dataset_from_graph(graph), 301 dataset.element_spec) 302 303 def structuredElement(self, element_structure, shape=None, 304 dtype=dtypes.int64): 305 """Returns an element with the given structure.""" 306 if shape is None: 307 shape = [] 308 if element_structure is None: 309 return array_ops.zeros(shape, dtype=dtype) 310 else: 311 return tuple([ 312 self.structuredElement(substructure, shape, dtype) 313 for substructure in element_structure 314 ]) 315 316 def checkDeterminism(self, dataset_fn, expect_determinism, expected_elements): 317 """Tests whether a dataset produces its elements deterministically. 318 319 `dataset_fn` takes a delay_ms argument, which tells it how long to delay 320 production of the first dataset element. This gives us a way to trigger 321 out-of-order production of dataset elements. 322 323 Args: 324 dataset_fn: A function taking a delay_ms argument. 325 expect_determinism: Whether to expect deterministic ordering. 326 expected_elements: The elements expected to be produced by the dataset, 327 assuming the dataset produces elements in deterministic order. 328 """ 329 if expect_determinism: 330 dataset = dataset_fn(100) 331 actual = self.getDatasetOutput(dataset) 332 self.assertAllEqual(expected_elements, actual) 333 return 334 335 # We consider the test a success if it succeeds under any delay_ms. The 336 # delay_ms needed to observe non-deterministic ordering varies across 337 # test machines. Usually 10 or 100 milliseconds is enough, but on slow 338 # machines it could take longer. 339 for delay_ms in [10, 100, 1000, 20000]: 340 dataset = dataset_fn(delay_ms) 341 actual = self.getDatasetOutput(dataset) 342 self.assertCountEqual(expected_elements, actual) 343 for i in range(len(actual)): 344 if actual[i] != expected_elements[i]: 345 return 346 self.fail("Failed to observe nondeterministic ordering") 347